---
title: "EBFT Training"
description: "Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions."
order: 9
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the **internal feature representations** of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.

Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)

### How EBFT Differs from Other RL Methods

| Method | Reward Signal | Requires | Best For |
|--------|--------------|----------|----------|
| **GRPO** | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| **DPO** | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
| **EBFT** | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |

EBFT's key advantage is that it needs only ground-truth completions -- no reward engineering, no preference annotation, and no reward model training. The model's own internal representations serve as the reward signal. This makes it particularly effective for:

- Code generation (match features of known-good solutions)
- Instruction following with reference outputs
- Continual pretraining on unstructured text (strided mode)
- Multi-turn dialogue with reference conversations

### Reward Formulation

The EBFT reward for each generated completion is:

```
reward = alignment_coef * cosine_similarity(gen_features, gt_features)
         - diversity_coef * mean_pairwise_similarity(gen_features)
```

- **Alignment**: How closely the generated output's internal representations match the ground truth. Higher is better.
- **Diversity**: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
- **CFM loss** (Cross-Feature Matching): Tracks `||mean(gen_features) - gt_features||^2` as a diagnostic. This is the quantity that EBFT ultimately minimizes.

## Modes

EBFT supports three operational modes, each suited to different use cases.

### Structured Mode (Sync)

Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.

```
GPU 0: vLLM Server (generates completions, receives weight syncs)
GPU 1: Trainer (feature extraction, reward computation, GRPO training)
```

**When to use**: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.

### Structured Mode (Async)

Same architecture as sync mode, but generation of the next batch overlaps with training on the current batch, giving higher throughput at the cost of slightly stale weights during generation.

**When to use**: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by `vllm_sync_interval`).

### Strided Mode

Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.

```
Single GPU: Base model + LoRA adapter
  - Strided block-parallel generation (flex_attention)
  - Feature extraction via disable_adapter()
  - No vLLM needed
```

**When to use**: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.

## Quick Start

### Structured Mode

This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.

**Step 1**: Create a config file `ebft_quickstart.yaml`:

```yaml
base_model: Qwen/Qwen2-0.5B-Instruct

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  alignment_coef: 1.0
  diversity_coef: 1.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: false
  scale_rewards: true
  loss_type: grpo

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 1024

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

# Standard training settings (see getting-started.qmd for details)
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```

**Step 2**: Start vLLM on GPU 0:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml
```

**Step 3**: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml
```

::: {.callout-important}
The `micro_batch_size` must be divisible by `num_generations`. For example, with `num_generations: 4`, valid values are 4, 8, 12, etc.
:::

### Dataset Format

Structured mode datasets must produce two fields after the transform:

- `prompt`: Either a string or a list of chat messages (`[{"role": "user", "content": "..."}]`)
- `ground_truth`: A string containing the reference completion

Example raw dataset row:

```json
{
  "input": "Write a function to compute fibonacci numbers.",
  "output": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)"
}
```

The `ebft_opencode.transform` converts this to the required `{prompt, ground_truth}` format automatically.

## Feature Extraction

EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.

### Feature Layers

The `feature_layers` parameter specifies which layers to extract, as fractions of total model depth:

```yaml
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Quarter, middle, three-quarter depth
```

For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size `len(feature_layers) * hidden_dim`.

::: {.callout-tip}
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default `[0.25, 0.5, 0.75]` works well across model sizes.
:::

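As a mental model, the fractional depths resolve to concrete layer indices at load time and the selected hidden states are concatenated. The sketch below illustrates this with hypothetical helper names (this is not Axolotl's internal code):

```python
import torch

def select_layer_indices(feature_layers: list[float], num_layers: int) -> list[int]:
    # Map fractional depths to layer indices, e.g. 0.5 -> layer 16 of a 32-layer model.
    return [int(frac * num_layers) for frac in feature_layers]

def concat_layer_features(hidden_states: tuple, feature_layers: list[float]) -> torch.Tensor:
    # `hidden_states` as returned by a forward pass with output_hidden_states=True:
    # a tuple of (num_layers + 1) tensors of shape (B, T, D), entry 0 being the embeddings.
    num_layers = len(hidden_states) - 1
    indices = select_layer_indices(feature_layers, num_layers)
    # Concatenate the selected layers along the feature dim -> (B, T, len(indices) * D).
    return torch.cat([hidden_states[i] for i in indices], dim=-1)
```
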
### Embed Methods

The `embed_method` controls how per-token hidden states are pooled into a single vector per sequence:

| Method | Description | Output Shape | Notes |
|--------|-------------|--------------|-------|
| `last_token` | Hidden state at the last non-padding token | `(B, D)` | Default. Good for autoregressive models where the last token summarizes the sequence. |
| `mean_pooling` | Mean of all non-padding token states | `(B, D)` | Considers the entire sequence equally. |
| `completion_mean` | Mean over completion tokens only (excludes prompt) | `(B, D)` | Focuses reward signal on generated content. Requires prompt length information. |
| `concat` | Concatenation of states at 25%, 50%, 75% positions | `(B, 3*D)` | Captures positional structure. Higher dimensional. |

```yaml
ebft:
  embed_method: completion_mean  # Focus on completion features
```

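The pooling methods are simple reductions over the token axis. A minimal sketch of the first three, with illustrative names rather than the trainer's exact code (`concat` additionally samples states at fixed fractional positions):

```python
import torch

def pool(hidden: torch.Tensor, mask: torch.Tensor, method: str, prompt_len: int = 0):
    # hidden: (B, T, D) per-token states; mask: (B, T) attention mask.
    if method == "last_token":
        last = mask.sum(dim=1) - 1  # index of the last non-padding token
        return hidden[torch.arange(hidden.size(0)), last]               # (B, D)
    if method == "mean_pooling":
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(dim=1) / m.sum(dim=1)                   # (B, D)
    if method == "completion_mean":
        comp_mask = mask.clone()
        comp_mask[:, :prompt_len] = 0  # zero out prompt positions
        m = comp_mask.unsqueeze(-1).float()
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)      # (B, D)
    raise ValueError(f"unknown embed_method: {method}")
```
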
### SVD Whitening

Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.

```yaml
ebft:
  use_whitening: true
```

When whitening is enabled, the reward computation applies a whitening matrix `W = U @ diag(1/S) @ U^T` derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.

::: {.callout-note}
Singular values scale with `sqrt(batch_size)`, so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (`n_samples_per_prompt` or `num_generations`) is fixed during training.
:::

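A sketch of one plausible reading of that formula, taking the SVD so that `U` spans the feature dimension (the exact convention lives in the trainer source; the epsilon clamp is an added numerical safeguard, not from the paper):

```python
import torch

def whitening_matrix(gen_emb: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # gen_emb: (n_samples, D). Transpose so that U spans the feature space.
    U, S, _ = torch.linalg.svd(gen_emb.T, full_matrices=False)  # U: (D, k), S: (k,)
    return U @ torch.diag(1.0 / S.clamp(min=eps)) @ U.T         # W: (D, D)

# The same W is applied to both sets of embeddings before computing rewards:
# gen_white = gen_emb @ W
# gt_white  = gt_emb @ W
```
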
### Alignment and Diversity Coefficients

The two reward components are weighted by coefficients:

```yaml
ebft:
  alignment_coef: 1.0  # Weight for cosine similarity with ground truth
  diversity_coef: 1.0  # Weight for pairwise similarity penalty
```

Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:

```
reward_j = 2 * alignment_coef * cos(gen_j, gt)
           - 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')
```

Setting `diversity_coef: 0.0` disables the diversity penalty entirely, which may be appropriate when `num_generations` is small (e.g., 2).

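Put together, the per-prompt reward is a direct transcription of the formula above. A sketch with illustrative tensor names, assuming pooled (and optionally whitened) embeddings, including the `ebft/cfm_loss` diagnostic:

```python
import torch
import torch.nn.functional as F

def ebft_rewards(gen: torch.Tensor, gt: torch.Tensor,
                 alignment_coef: float = 1.0, diversity_coef: float = 1.0) -> torch.Tensor:
    # gen: (n, D) embeddings of the n generations for one prompt; gt: (D,) ground truth.
    n = gen.size(0)
    align = F.cosine_similarity(gen, gt.unsqueeze(0), dim=-1)       # (n,)
    dots = gen @ gen.T                                              # (n, n) pairwise dot products
    pairwise = (dots.sum(dim=1) - dots.diagonal()) / max(n - 1, 1)  # exclude self-pairs
    return 2 * alignment_coef * align - 2 * diversity_coef * pairwise

def cfm_loss(gen: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    # The diagnostic tracked as ebft/cfm_loss: ||mean(gen_features) - gt_features||^2
    return (gen.mean(dim=0) - gt).pow(2).sum()
```
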
## Strided Mode

Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places **anchor points** at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.

### How Block-Parallel Generation Works

Given a document of length `S` tokens:

1. **Anchor placement**: Starting at position `anchor_offset`, place anchors every `stride` tokens. Each anchor defines a block.
2. **Context window**: Each block sees `context_length` tokens of preceding context from the original document.
3. **Generation**: At each anchor, generate `generate_max_len` tokens autoregressively, conditioned only on the context window.
4. **Parallelism**: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.

```
Document: [tok0, tok1, ..., tok_S]
              |           |           |
          anchor_0    anchor_1    anchor_2
              |           |           |
          [ctx][gen]  [ctx][gen]  [ctx][gen]
```

The attention mask ensures:

- Prompt tokens use standard causal attention
- Each generated block attends to its own context window and its own preceding generated tokens
- Blocks do not attend to each other's generated tokens

When `flex_attention` is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.

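For intuition, anchor placement reduces to simple arithmetic over token positions. A sketch under the semantics described above (function and argument handling are illustrative, not Axolotl's implementation):

```python
def place_anchors(seq_len: int, stride: int, context_length: int,
                  anchor_offset: int = 0, prompt_len: int = 0,
                  min_completion_prefix: int = 0) -> list[int]:
    # Each anchor starts a block that sees `context_length` preceding document tokens;
    # anchors inside the prompt region (plus min_completion_prefix) are skipped.
    first = max(anchor_offset, context_length, prompt_len + min_completion_prefix)
    return list(range(first, seq_len, stride))

# A 64-token document with stride=8 and context_length=8:
# place_anchors(64, stride=8, context_length=8) -> [8, 16, 24, 32, 40, 48, 56]
```
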
### Strided Mode Configuration

```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided
  stride: 8                  # Tokens between anchor points
  context_length: 8          # Context window per block
  generate_max_len: 8        # Tokens to generate per block
  n_samples_per_prompt: 4    # Independent rollouts per document
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0               # RL policy gradient loss weight
  ce_coef: 0.03              # Cross-entropy loss on GT tokens
  advantage_estimator: rloo  # rloo, group_norm, or reinforce
  min_completion_prefix: 8   # Skip anchors in prompt region

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true

bf16: auto
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required with flex_attention
```

Run with a single command (no vLLM needed):

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

### Advantage Estimators

Strided mode supports three advantage estimation methods:

| Estimator | Formula | Requirements |
|-----------|---------|--------------|
| `rloo` | Leave-one-out baseline: `reward_j - mean(rewards_{-j})` | `n_samples_per_prompt >= 2` |
| `group_norm` | Group normalization: `(reward_j - mean) / std` | `n_samples_per_prompt >= 2` |
| `reinforce` | Raw reward as advantage (no baseline) | Works with `n_samples_per_prompt = 1` |

::: {.callout-warning}
When `n_samples_per_prompt: 1`, the trainer automatically falls back to `reinforce` and disables the diversity penalty (which requires multiple samples).
:::

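In code, the three estimators reduce to a few lines each. A sketch over the per-document reward vector (the `1e-8` term is an added numerical guard, not from the source):

```python
import torch

def compute_advantages(rewards: torch.Tensor, estimator: str) -> torch.Tensor:
    # rewards: (n,) rewards for the n rollouts of one document.
    n = rewards.numel()
    if estimator == "rloo" and n >= 2:
        baseline = (rewards.sum() - rewards) / (n - 1)  # leave-one-out mean
        return rewards - baseline
    if estimator == "group_norm" and n >= 2:
        return (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    return rewards  # reinforce: raw reward, no baseline
```
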
### Strided Mode Constraints

- **`flex_attention: true`** is strongly recommended. Without it, dense 4D masks consume significantly more memory.
- **`torch_compile: true`** must NOT be set. `flex_attention` compiles its own kernels internally; adding `torch_compile` causes conflicts and OOM.
- **Gradient checkpointing** must use `use_reentrant: true`. Non-reentrant checkpointing causes `CheckpointError` with `flex_attention` block masks.
- **`activation_offloading`** is incompatible with `flex_attention`.

### Cross-Entropy Loss

Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:

```yaml
ebft:
  ce_coef: 0.03  # Small CE coefficient
  rl_coef: 1.0   # RL loss coefficient
```

The total loss is `rl_coef * rl_loss + ce_coef * ce_loss`. For structured mode, `ce_coef` is typically `0.0` since vLLM generation provides sufficient learning signal.

## Dataset Formats

EBFT provides several built-in dataset transforms in `src/axolotl/prompt_strategies/ebft/`.

### Built-In Transforms

| Transform | Input Format | Output Fields | Use Case |
|-----------|--------------|---------------|----------|
| `ebft_opencode.transform` | `{input, output}` | `{prompt, ground_truth}` | OpenCodeInstruct, structured QA |
| `ebft_strided_structured.transform` | `{input, output}` | `{input_ids, labels, prompt_length}` | Strided mode with structured data |
| `ebft_strided_chat.transform` | `{messages: [...]}` | `{input_ids, labels, prompt_length}` | Strided mode with chat data |
| `ebft_chat_multiturn.transform` | `{messages: [...]}` | `{prompt, ground_truth, remaining_turns}` | Multi-turn: first-turn target |
| `ebft_chat_multiturn.transform_last_turn` | `{messages: [...]}` | `{prompt, ground_truth}` | Multi-turn: last-turn target |
| `ebft_chat_multiturn.transform_all_turns` | `{messages: [...]}` | `{prompt[], ground_truth[]}` | Multi-turn: one example per turn |
| `ebft_reasoning.transform` | `{messages: [...]}` (with `<think>`) | `{prompt, ground_truth}` | Reasoning/thinking datasets |

### Structured Mode Datasets

For structured (sync/async) mode, the transform must produce `prompt` and `ground_truth` fields:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
```

### Multi-Turn Datasets

Multi-turn transforms extract conversation data for sequential rollout. The `transform` variant targets the first assistant turn, while `transform_last_turn` targets the final turn:

```yaml
datasets:
  - path: your/multiturn-dataset
    type: ebft_chat_multiturn.transform
```

When `remaining_turns` is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.

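Conceptually, the sequential rollout loop looks like the sketch below, with `generate` standing in for the vLLM call (an illustration of the idea, not the trainer's actual code):

```python
def rollout_remaining_turns(generate, history: list[dict], remaining_turns: list[str]):
    # history: chat messages so far; remaining_turns: the upcoming user messages.
    for user_msg in remaining_turns:
        history.append({"role": "user", "content": user_msg})
        # The model produces the next assistant turn conditioned on the full history.
        history.append({"role": "assistant", "content": generate(history)})
    return history
```
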
### Strided Mode Datasets

Strided transforms tokenize the full document and produce `input_ids`, `labels`, and `prompt_length`:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]
```

### Custom Transforms

To use your own dataset format, write a transform function:

```python
def transform(cfg, **kwargs):
    def transform_fn(example, tokenizer=None):
        # Map your dataset's fields to the two required EBFT fields.
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "ground_truth": example["answer"],
        }

    return transform_fn, {"remove_columns": "__all__"}
```

The `"__all__"` sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:

```yaml
datasets:
  - path: your/dataset
    type: your_module.transform
```

## Configuration Reference

### Common Parameters (All Modes)

These parameters are set under the `ebft:` key in the YAML config.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `mode` | `"structured"` or `"strided"` | `"structured"` | EBFT operating mode |
| `feature_layers` | `list[float]` | `[0.25, 0.5, 0.75]` | Fractional layer depths for feature extraction |
| `embed_method` | `string` | `"last_token"` | Pooling method: `last_token`, `mean_pooling`, `completion_mean`, or `concat` |
| `use_whitening` | `bool` | `false` | Apply SVD whitening to feature embeddings before reward computation |
| `alignment_coef` | `float` | `1.0` | Weight for alignment reward (cosine similarity with ground truth) |
| `diversity_coef` | `float` | `1.0` | Weight for diversity penalty (pairwise dot product between samples) |
| `ce_coef` | `float` | `0.0` | Cross-entropy loss coefficient on ground-truth tokens |
| `adaptive_max_tokens` | `bool` | `true` | Dynamically set vLLM `max_tokens` based on ground-truth length (structured mode) |
| `gt_length_multiplier` | `float` | `1.5` | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |

### Strided Mode Parameters

These additional parameters apply only when `mode: strided`.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `stride` | `int` | `8` | Number of tokens between anchor points (must be >= 1) |
| `context_length` | `int` | `8` | Context window size for each generated block (must be >= 1) |
| `generate_max_len` | `int` | `8` | Number of tokens to generate per block (must be >= 1) |
| `n_samples_per_prompt` | `int` | `4` | Number of independent rollouts per document (must be >= 1) |
| `temperature` | `float` | `0.6` | Sampling temperature for strided generation |
| `top_p` | `float` | `1.0` | Top-p nucleus sampling threshold |
| `rl_coef` | `float` | `1.0` | RL policy gradient loss coefficient |
| `advantage_estimator` | `string` | `"rloo"` | Advantage estimation method: `rloo`, `group_norm`, or `reinforce` |
| `min_completion_prefix` | `int` | `0` | Minimum tokens into the completion span before placing anchors |

### Structured Mode TRL Parameters

These are set under the `trl:` key and control the GRPO training loop.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `num_generations` | `int` | -- | Number of completions generated per prompt |
| `max_completion_length` | `int` | -- | Maximum tokens per generated completion |
| `temperature` | `float` | `0.7` | Sampling temperature for vLLM generation |
| `use_vllm` | `bool` | -- | Enable vLLM generation backend |
| `vllm_lora_sync` | `bool` | `false` | Sync LoRA adapters via filesystem (recommended) |
| `vllm_sync_interval` | `int` | `1` | Steps between weight syncs to vLLM |
| `use_data_producer` | `bool` | -- | Required for sync mode with LoRA sync |
| `async_prefetch` | `bool` | `false` | Enable async generation (overlaps with training) |
| `streaming_partial_batch` | `bool` | `false` | Score groups incrementally (async mode) |
| `skip_zero_advantage_batches` | `bool` | `false` | Skip micro-batches where all advantages are zero |
| `scale_rewards` | `bool` | -- | Normalize rewards within each prompt group |
| `loss_type` | `string` | `"grpo"` | Loss type for policy optimization |
| `epsilon` | `float` | `0.2` | Clipping parameter for importance sampling |

### Stop Tokens

vLLM needs explicit stop token IDs for generation. Common configurations:

```yaml
trl:
  generation_kwargs:
    stop_token_ids: [151645, 151643]  # Qwen: <|im_end|>, <|endoftext|>
```

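To find the right IDs for your model, you can query the tokenizer directly (the printed values assume the Qwen tokenizer shown above):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
print(tok.convert_tokens_to_ids("<|im_end|>"))     # 151645
print(tok.convert_tokens_to_ids("<|endoftext|>"))  # 151643
print(tok.eos_token_id)                            # the model's default EOS id
```
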
### Multi-Turn Chat Settings

For multi-turn conversations with Qwen3.5, disable thinking mode to prevent `<think>` tags in completions:

```yaml
trl:
  chat_template_kwargs:
    enable_thinking: false
```

## Monitoring

### Key Metrics

EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:

| Metric | Healthy Range | Interpretation |
|--------|---------------|----------------|
| `ebft/alignment` | 0.3 -- 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| `ebft/diversity` | 0.01 -- 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| `ebft/cfm_loss` | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| `ebft/reward` | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| `grad_norm` | 0.1 -- 3.0 | Gradient magnitude. Values of 0.0 indicate a zero-advantage skip (normal). Values above 10 suggest instability. |
| `entropy` | 0.05 -- 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
| `IS ratio min` | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; decrease `vllm_sync_interval` to sync weights more often. |

### Console Log Example

During training, you will see periodic EBFT reward logs:

```
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
```

The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.

### Troubleshooting

| Symptom | Likely Cause | Fix |
|---------|--------------|-----|
| `alignment` stays below 0.1 | Feature layers not capturing useful information | Try different `feature_layers` or `embed_method` |
| `diversity` exceeds 1.0 | Mode collapse -- generations are too similar | Increase `diversity_coef` or `temperature` |
| `reward` stuck at -1.0 | Diversity penalty dominates alignment | Reduce `diversity_coef` or increase `alignment_coef` |
| `grad_norm` consistently 0.0 | All micro-batches have zero advantage | Increase `num_generations` or check data quality |
| `CheckpointError` in strided mode | Incompatible gradient checkpointing settings | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| OOM during training | Logits tensor too large | Reduce `sequence_len` or `micro_batch_size`; strided mode uses a chunked lm_head to mitigate this |
| vLLM 500 errors | `truncate_prompt_tokens` not supported | Ensure you are using `axolotl vllm-serve` (not `trl vllm-serve`) |

### Feature Network Memory

In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the `disable_adapter()` context manager. This saves an entire model copy in VRAM (approximately 1--16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.

::: {.callout-note}
The `disable_adapter()` approach relies on an invariant: `merge_adapter()` is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
:::

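Conceptually, feature extraction in PEFT mode looks like the sketch below, using PEFT's `disable_adapter()` context manager (a simplified illustration, not the trainer's actual code path; `model` is assumed to be a `peft.PeftModel`):

```python
import torch

def reference_hidden_states(model, input_ids):
    # disable_adapter() temporarily removes the LoRA contribution, so this forward
    # pass sees only the frozen base weights -- no second model copy in VRAM.
    with torch.no_grad(), model.disable_adapter():
        out = model(input_ids=input_ids, output_hidden_states=True)
    return out.hidden_states
```
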
## Examples

Complete example configurations are available in `examples/ebft/`:

| Config | Model | Mode | Description |
|--------|-------|------|-------------|
| `llama-1b-ebft-strided-structured.yaml` | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
| `qwen3-4b-ebft-structured.yaml` | Qwen3 4B | Structured (sync) | Two-GPU structured training |
| `qwen3-4b-ebft-structured-async.yaml` | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
| `qwen3-8b-ebft-structured.yaml` | Qwen3 8B | Structured (sync) | Two-GPU structured training for larger model |
| `qwen35-4b-ebft-structured.yaml` | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
| `qwen35-4b-ebft-structured-async.yaml` | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
| `qwen35-9b-ebft-structured.yaml` | Qwen3.5 9B | Structured (sync) | Two-GPU structured for 9B model |