EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip * fixes * more fixes * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address code review feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
examples/ebft/README.md (new file, +214 lines)
# Energy-Based Fine-Tuning (EBFT)

EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl.

## Overview

EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

**Key advantages over SFT:**

- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously

**Key advantages over RLVR:**

- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)

## Two Modes

EBFT supports two modes depending on your data format:

### Structured Mode (`mode: structured`, default)

For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).

- Extends GRPOTrainer — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions

### Strided Mode (`mode: strided`)

For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).

- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)

### Common to both modes

- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)
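Putting the bullets above together, the core loop looks roughly like the following sketch. It is illustrative only (assumed HuggingFace-style model API and hypothetical function names, not the axolotl trainer code): last-token features at fractional depths, L2-normalized per layer, then a cosine-alignment-minus-pairwise-similarity reward.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def extract_features(feature_net, input_ids, attention_mask, fractions=(0.25, 0.5, 0.75)):
    """Last-token hidden states at fractional depths, L2-normalized per layer, concatenated."""
    out = feature_net(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden = out.hidden_states                     # tuple: embeddings + one entry per transformer layer
    n_layers = len(hidden) - 1
    last_pos = attention_mask.sum(dim=1) - 1       # index of the last real token in each sequence
    feats = []
    for frac in fractions:
        layer = hidden[max(1, round(frac * n_layers))]
        pooled = layer[torch.arange(layer.size(0)), last_pos]   # (batch, hidden_dim)
        feats.append(F.normalize(pooled, dim=-1))               # L2-normalize each layer's features
    return torch.cat(feats, dim=-1)


def feature_matching_reward(gen_feats, gt_feats, alignment_coef=1.0, diversity_coef=1.0):
    """Reward = alignment with the ground truth minus mean pairwise similarity among samples.

    gen_feats: (n_samples, dim) features of sampled completions; gt_feats: (dim,) ground-truth features.
    """
    alignment = F.cosine_similarity(gen_feats, gt_feats.expand_as(gen_feats), dim=-1)
    sims = gen_feats @ gen_feats.T                              # pairwise dot products between samples
    off_diag = sims - torch.diag(torch.diag(sims))
    diversity = off_diag.sum(dim=-1) / max(gen_feats.size(0) - 1, 1)
    return alignment_coef * alignment - diversity_coef * diversity
```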
## Quick Start

### Structured Mode (QA data + vLLM)

```bash
# 1. Start vLLM server
python -m trl.scripts.vllm_serve \
    --model meta-llama/Llama-3.2-1B \
    --host 0.0.0.0 --port 8000 \
    --gpu-memory-utilization 0.3

# 2. Train
axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
```

### Strided Mode (unstructured text)

```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```

## Configuration

### Common EBFT Settings

```yaml
rl: ebft

ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]

  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token

  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true

  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0

  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
```
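For reference, the `embed_method` options above correspond roughly to the following pooling functions. This is an illustrative sketch, not the trainer's implementation; the `concat` option is omitted here.

```python
import torch


def pool_hidden_states(hidden, attention_mask, embed_method="last_token"):
    """Pool per-token hidden states (batch, seq, dim) into sequence embeddings (batch, dim)."""
    if embed_method == "last_token":
        last_pos = attention_mask.sum(dim=1) - 1                 # last non-padding position
        return hidden[torch.arange(hidden.size(0)), last_pos]
    if embed_method == "mean_pooling":
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)     # zero out padding positions
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    raise ValueError(f"unsupported embed_method: {embed_method}")
```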
### Strided Mode Settings

```yaml
ebft:
  mode: strided
  stride: 8                  # tokens between anchor points (paper default: 8)
  context_length: 8          # context window per block (paper default: 8)
  generate_max_len: 8        # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4    # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0               # RL loss weight
  advantage_estimator: rloo  # rloo (recommended), group_norm, or reinforce
```

### Structured Mode Settings (via TRL)

```yaml
trl:
  num_generations: 4           # samples per prompt
  max_completion_length: 256   # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
```

### Dataset Format

**Structured mode** — QA data with prompt + ground-truth completion:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
```

Transform returns: `{"prompt": ..., "ground_truth": ...}`

**Strided mode** — raw text tokenized to fixed length:

```yaml
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
```

Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`

## How It Works

### Structured Mode

1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
4. **RLOO advantages**: subtract leave-one-out group mean
5. **Policy gradient**: clipped PPO-style loss
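Steps 3–5 can be sketched as below. This is a schematic with hypothetical helper names and simplified shapes: rewards are per sequence, advantages are broadcast over the completion tokens, and the clipped surrogate mirrors the GRPO/PPO objective with `epsilon` from the config.

```python
import torch


def rloo_advantages(rewards):
    """Leave-one-out baseline: subtract the mean reward of the *other* samples in the group.

    rewards: (num_generations,) feature-matching rewards for one prompt.
    """
    n = rewards.numel()
    baseline = (rewards.sum() - rewards) / (n - 1)
    return rewards - baseline


def clipped_pg_loss(logprobs, old_logprobs, advantages, epsilon=0.2):
    """PPO-style clipped surrogate; advantages are broadcast across completion tokens."""
    ratio = (logprobs - old_logprobs).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1 - epsilon, 1 + epsilon) * advantages
    return -torch.minimum(unclipped, clipped).mean()
```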
### Strided Mode

1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document
2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
4. **Per-block rewards**:
   - **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
   - **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
   - **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
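The anchor count in step 1 follows directly from the config values. A small sketch (the exact anchor offsets used by the trainer are an assumption here):

```python
def anchor_positions(seq_len, stride=8, ctx_len=8, gen_len=8):
    """num_blocks = (seq_len - gen_len - ctx_len) // stride + 1 anchor points per document."""
    num_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    return [ctx_len + i * stride for i in range(max(num_blocks, 0))]


# A 256-token document with the paper defaults yields 31 blocks anchored at 8, 16, ..., 248.
print(len(anchor_positions(256)))  # 31
```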
### Tracked Metrics

| Metric | Description |
|--------|-------------|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | alignment - diversity (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
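`ebft/cfm_loss` in the table corresponds to the squared distance between the mean generated feature and the ground-truth feature. A sketch of paper eq 2, not the logged implementation:

```python
import torch


def cfm_loss(gen_feats, gt_feats):
    """||E[phi(y_hat)] - phi(y)||^2 for one prompt.

    gen_feats: (n_samples, dim) features of sampled completions; gt_feats: (dim,) ground-truth features.
    """
    return (gen_feats.mean(dim=0) - gt_feats).pow(2).sum()
```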
## Tips and Recommendations

### Reward coefficients

- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching.
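One common way to implement the SVD whitening referenced above is PCA-style whitening of the centered feature matrix. The sketch below is illustrative and may differ from the trainer's exact recipe (for example, in which features are stacked into the whitening batch):

```python
import torch


def svd_whiten(feats, eps=1e-6):
    """Decorrelate feature dimensions: center, project onto principal directions, rescale to unit variance.

    feats: (n, dim) stacked generated and ground-truth features for the current batch.
    """
    centered = feats - feats.mean(dim=0, keepdim=True)
    u, s, vh = torch.linalg.svd(centered, full_matrices=False)
    scale = max(feats.size(0) - 1, 1) ** 0.5
    return (centered @ vh.T) / (s + eps) * scale
```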
### Feature extraction

- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).

### Performance

- **`torch_compile: true`**: Recommended for strided mode. Provides additional speedup via graph compilation.
- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.
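The block-mask API that strided mode relies on looks like the following. This example compiles a simple causal sliding-window mask just to show the `create_block_mask` / `flex_attention` pattern; the actual strided EBFT mask additionally isolates each generated block from the other blocks and from future document tokens.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SEQ_LEN, WINDOW = 1024, 16


def causal_windowed(b, h, q_idx, kv_idx):
    # Causal attention restricted to a 16-token window (illustrative mask predicate only)
    return (kv_idx <= q_idx) & (q_idx - kv_idx < WINDOW)


block_mask = create_block_mask(causal_windowed, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

q = k = v = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)
```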
### Memory

- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.

## Example Configs

| Config | Mode | Model | Description |
|--------|------|-------|-------------|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |

## Citation

```bibtex
@article{jelassi2026matching,
  title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
  author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
  journal={arXiv preprint arXiv:2603.12248},
  year={2026}
}
```
examples/ebft/ebft_opencode.py (new file, +28 lines)
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT.

Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer.
"""


def transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "user", "content": example["input"]},
            ],
            "ground_truth": example["output"],
        }

    return transform_fn, {
        "remove_columns": [
            "id",
            "domain",
            "generation_algorithm",
            "llm_judgement",
            "unit_tests",
            "tests_execution_status",
            "average_test_score",
        ]
    }
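For illustration (not part of the committed file), applying this transform to a single OpenCodeInstruct-style row would look like the snippet below; the import path is assumed to be on `PYTHONPATH`.

```python
from ebft_opencode import transform

transform_fn, map_kwargs = transform(cfg=None)
row = {"input": "Reverse a string.", "output": "def rev(s):\n    return s[::-1]"}
print(transform_fn(row))
# {'prompt': [{'role': 'user', 'content': 'Reverse a string.'}],
#  'ground_truth': 'def rev(s):\n    return s[::-1]'}
```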
examples/ebft/ebft_pretrain.py (new file, +31 lines)
"""
Dataset transform for unstructured text data with strided EBFT.

Tokenizes raw text into fixed-length input_ids for the strided trainer.
Sequences are padded to sequence_len for uniform batching.
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        text = example.get("question", example.get("text", ""))
        if tokenizer is None:
            return {"prompt": text}

        encoded = tokenizer(
            text,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            add_special_tokens=True,
            return_tensors=None,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": list(encoded["input_ids"]),
        }

    return transform_fn, {"remove_columns": ["question", "answer"]}
examples/ebft/ebft_strided_structured.py (new file, +80 lines)
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.

Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).

Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        # Extract prompt and completion from the example
        prompt_text = example.get(
            "input", example.get("prompt", example.get("question", ""))
        )
        completion_text = example.get(
            "output", example.get("completion", example.get("answer", ""))
        )

        if tokenizer is None:
            return {"prompt": prompt_text}

        pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id

        # Tokenize prompt and completion separately
        prompt_enc = tokenizer(
            prompt_text,
            truncation=False,
            add_special_tokens=True,
            return_tensors=None,
        )
        completion_enc = tokenizer(
            completion_text,
            truncation=False,
            add_special_tokens=False,
            return_tensors=None,
        )

        prompt_ids = prompt_enc["input_ids"]
        completion_ids = completion_enc["input_ids"]

        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
        total_len = len(prompt_ids) + len(completion_ids)
        if total_len > seq_len:
            # Truncate completion first, then prompt if needed
            max_completion = seq_len - len(prompt_ids)
            if max_completion < 1:
                # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
                prompt_ids = prompt_ids[: seq_len - 1]
                completion_ids = completion_ids[:1]
            else:
                completion_ids = completion_ids[:max_completion]

        input_ids = prompt_ids + completion_ids
        prompt_length = len(prompt_ids)

        # Labels: -100 for prompt tokens, input_ids for completion tokens
        labels = [-100] * prompt_length + completion_ids

        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    # Signal to remove all original columns (filtered to existing ones at map time)
    return transform_fn, {
        "remove_columns": "__all__",
    }
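A quick sanity check of the label masking, for illustration only (not part of the committed file; the import path and tokenizer choice are assumptions):

```python
from types import SimpleNamespace

from transformers import AutoTokenizer

from ebft_strided_structured import transform

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
transform_fn, _ = transform(SimpleNamespace(sequence_len=64))

row = {"input": "Add two numbers.", "output": "def add(a, b):\n    return a + b"}
out = transform_fn(row, tokenizer=tokenizer)

# Prompt tokens are masked with -100; everything is padded out to sequence_len.
assert len(out["input_ids"]) == 64
assert out["labels"][: out["prompt_length"]] == [-100] * out["prompt_length"]
```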
examples/ebft/llama-1b-ebft-opencode-novllm.yaml (new file, +64 lines)
# EBFT validation config — no vLLM, uses HF generate for simplicity
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 128
  temperature: 1.0
  use_vllm: false
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]

sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 10

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-validation

wandb_project: ebft
wandb_run_id:
wandb_watch:
wandb_log_model:

logging_steps: 1
save_steps: 100
examples/ebft/llama-1b-ebft-opencode.yaml (new file, +81 lines)
# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
#
# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
# https://arxiv.org/abs/2603.12248
#
# Prerequisites:
# 1. Start vLLM server on a separate GPU:
#    CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
#      --model meta-llama/Llama-3.2-1B \
#      --host 0.0.0.0 --port 8000 \
#      --gpu-memory-utilization 0.4 --dtype bfloat16
#
# 2. Run training:
#    CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3

# --- Training method ---
rl: ebft

# --- EBFT configuration ---
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # extract hidden states at 25%, 50%, 75% depth
  embed_method: last_token           # pool to sequence embedding via last token
  use_whitening: false               # SVD whitening (disable for speed in small runs)
  alignment_coef: 1.0                # cosine similarity with ground-truth features
  diversity_coef: 1.0                # pairwise similarity penalty
  ce_coef: 0.0                       # cross-entropy on ground-truth (0 = pure feature matching)

# --- Generation settings (via TRL/GRPO infrastructure) ---
trl:
  num_generations: 4                 # samples per prompt for RLOO
  max_completion_length: 256         # max generated tokens
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

# --- Dataset ---
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]                # first 1% for validation runs

# --- Training hyperparameters ---
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 50

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
weight_decay: 0.01

# --- LoRA (recommended to reduce memory with frozen feature network) ---
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

# --- Hardware ---
bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama-1b-opencode

# --- Logging ---
use_tensorboard: true
logging_steps: 1
save_steps: 25
examples/ebft/llama-1b-ebft-strided-structured.yaml (new file, +65 lines)
# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided              # strided block-parallel generation
  stride: 8                  # tokens between anchor points
  context_length: 8          # context window per block
  generate_max_len: 8        # tokens to generate per block
  n_samples_per_prompt: 4    # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03              # small CE weight for structured data
  advantage_estimator: rloo
  min_completion_prefix: 8   # skip anchors too close to prompt boundary

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false   # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true     # fused flex_attention kernel compiles itself; don't set torch_compile: true
                         # (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true    # required for flex_attention (non-reentrant causes CheckpointError)

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured

wandb_project: ebft
logging_steps: 1
save_steps: 100
examples/ebft/llama-1b-ebft-strided.yaml (new file, +60 lines)
# EBFT Strided Mode: For unstructured text data (raw code, prose)
# Uses strided block-parallel generation — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided              # strided block-parallel generation
  stride: 8                  # tokens between anchor points
  context_length: 8          # context window per block
  generate_max_len: 8        # tokens to generate per block
  n_samples_per_prompt: 4    # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:100]

sequence_len: 256
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 5

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false   # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-validation

wandb_project: ebft
logging_steps: 1
save_steps: 100
examples/ebft/llama-3b-ebft-strided-fft.yaml (new file, +69 lines)
# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.2-3B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0               # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: true
ddp: false
device_map:
  "": 0

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided

wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50
examples/ebft/llama-8b-ebft-strided-fft.yaml (new file, +58 lines)
# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# Feature network is CPU-offloaded to fit on a single 32GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.1-8B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

bf16: auto
flash_attention: false   # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided

wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50
examples/ebft/qwen35-4b-ebft-structured-async.yaml (new file, +86 lines)
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml

base_model: Qwen/Qwen3.5-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  async_prefetch: true
  vllm_server_timeout: 300

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured-async

wandb_project: ebft
logging_steps: 1
save_steps: 50
examples/ebft/qwen35-4b-ebft-structured.yaml (new file, +77 lines)
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
#      --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml

base_model: Qwen/Qwen3.5-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false            # disable Qwen3.5 thinking mode for shorter completions

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured

wandb_project: ebft
logging_steps: 1
save_steps: 50
examples/ebft/qwen35-9b-ebft-structured.yaml (new file, +82 lines)
# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml

base_model: Qwen/Qwen3.5-9B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  vllm_server_timeout: 300

vllm:
  gpu_memory_utilization: 0.7
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 3.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-9b-structured

wandb_project: ebft
logging_steps: 1
save_steps: 50