EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixes

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address code review feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
Wing Lian
2026-03-24 18:43:46 -04:00
committed by GitHub
parent e9883c91d4
commit c50c4acbf4
48 changed files with 5885 additions and 168 deletions


@@ -22,6 +22,7 @@ RUN apt update && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
     printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
+    printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
     chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
     chmod +x /root/cloud-entrypoint.sh && \
     echo 'set-option -g history-limit 5000' >> ~/.tmux.conf

examples/ebft/README.md Normal file

@@ -0,0 +1,214 @@
# Energy-Based Fine-Tuning (EBFT)
EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl.
## Overview
EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
**Key advantages over SFT:**
- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously
**Key advantages over RLVR:**
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)
## Two Modes
EBFT supports two modes depending on your data format:
### Structured Mode (`mode: structured`, default)
For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).
- Extends GRPOTrainer — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions
### Strided Mode (`mode: strided`)
For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).
- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)
### Common to both modes:
- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)
## Quick Start
### Structured Mode (QA data + vLLM)
```bash
# 1. Start vLLM server
python -m trl.scripts.vllm_serve \
--model meta-llama/Llama-3.2-1B \
--host 0.0.0.0 --port 8000 \
--gpu-memory-utilization 0.3
# 2. Train
axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
```
### Strided Mode (unstructured text)
```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```
## Configuration
### Common EBFT Settings
```yaml
rl: ebft
ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]
  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token
  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true
  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0
  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
```
### Strided Mode Settings
```yaml
ebft:
  mode: strided
  stride: 8 # tokens between anchor points (paper default: 8)
  context_length: 8 # context window per block (paper default: 8)
  generate_max_len: 8 # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4 # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0 # RL loss weight
  advantage_estimator: rloo # rloo (recommended), group_norm, or reinforce
```
### Structured Mode Settings (via TRL)
```yaml
trl:
  num_generations: 4 # samples per prompt
  max_completion_length: 256 # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
```
### Dataset Format
**Structured mode** — QA data with prompt + ground-truth completion:
```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
```
Transform returns: `{"prompt": ..., "ground_truth": ...}`
**Strided mode** — raw text tokenized to fixed length:
```yaml
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
```
Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`
## How It Works
### Structured Mode
1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
4. **RLOO advantages**: subtract leave-one-out group mean
5. **Policy gradient**: clipped PPO-style loss
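Step 4 can be illustrated with a minimal, dependency-free sketch of the leave-one-out baseline. The function name `rloo_advantages` is hypothetical and is not taken from the axolotl or TRL code; it only shows the arithmetic for one prompt's group of rewards.

```python
def rloo_advantages(rewards):
    """Leave-one-out advantages for the rewards of one prompt's rollouts."""
    n = len(rewards)
    if n < 2:
        raise ValueError("RLOO needs at least 2 samples per prompt")
    total = sum(rewards)
    # The baseline for sample i is the mean reward of the other n - 1 samples.
    return [r - (total - r) / (n - 1) for r in rewards]


# Four rollouts for a single prompt (illustrative reward values).
advantages = rloo_advantages([1.0, 0.5, 0.0, 0.5])
```

Because each sample is compared against the mean of its siblings, the advantages within a group sum to zero, which is why at least two generations per prompt are needed.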
### Strided Mode
1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document
2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
4. **Per-block rewards**:
- **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
- **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
- **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
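Steps 1 and 4 can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the trainer's implementation: the helper names (`num_blocks`, `cosine`, `block_rewards`) are hypothetical, the division in step 1 is assumed to be integer (floor) division, and the diversity term for a sample is assumed to be its mean dot product with the other rollouts of the same block.

```python
import math


def num_blocks(seq_len, stride=8, context_length=8, generate_max_len=8):
    """Anchor count from step 1, assuming floor division."""
    usable = seq_len - generate_max_len - context_length
    if usable < 0:
        return 0
    return usable // stride + 1


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def block_rewards(gen_embs, gt_emb, alignment_coef=1.0, diversity_coef=1.0):
    """Per-sample rewards for one block, given n >= 2 rollout embeddings."""
    n = len(gen_embs)
    rewards = []
    for j, emb in enumerate(gen_embs):
        alignment = 2.0 * cosine(emb, gt_emb)
        # Assumed form: mean raw dot product with the other rollouts.
        others = [
            sum(a * b for a, b in zip(emb, gen_embs[k]))
            for k in range(n)
            if k != j
        ]
        diversity = 2.0 * sum(others) / len(others)
        rewards.append(alignment_coef * alignment - diversity_coef * diversity)
    return rewards


blocks = num_blocks(seq_len=256)
rewards = block_rewards([[1.0, 0.0], [0.0, 1.0]], gt_emb=[1.0, 0.0])
```

In this toy case the first rollout matches the ground-truth direction and is orthogonal to its sibling, so it receives the full alignment reward with no diversity penalty.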
### Tracked Metrics
| Metric | Description |
|--------|-------------|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | alignment - diversity (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
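The `ebft/cfm_loss` entry can be read as the squared distance between the mean generated embedding and the ground-truth embedding. A minimal sketch of that formula for a single prompt follows; the function name `cfm_loss` is hypothetical and the expectation is approximated by the sample mean over rollouts.

```python
def cfm_loss(gen_embs, gt_emb):
    """Squared L2 distance between the mean rollout embedding and the GT embedding."""
    n = len(gen_embs)
    dim = len(gt_emb)
    # Sample-mean estimate of E[phi(y_hat)] over the n rollouts.
    mean = [sum(emb[d] for emb in gen_embs) / n for d in range(dim)]
    return sum((m - g) ** 2 for m, g in zip(mean, gt_emb))


loss = cfm_loss([[1.0, 0.0], [0.0, 1.0]], gt_emb=[0.5, 0.5])
```

Note that the loss depends only on the mean of the rollouts: two very different samples can still give a low value if they average out to the ground-truth features.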
## Tips and Recommendations
### Reward coefficients
- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching.
### Feature extraction
- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).
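The two settings above can be sketched as follows. Both helpers are hypothetical: `layer_indices` assumes a fractional depth maps to a layer by rounding `fraction * num_hidden_layers`, which may differ from the trainer's actual rule, and `concat_features` shows per-layer L2 normalization before concatenation on plain Python lists.

```python
import math


def layer_indices(feature_layers, num_hidden_layers):
    """Map fractional depths to layer indices (assumed rounding rule)."""
    return [
        min(num_hidden_layers, max(0, round(frac * num_hidden_layers)))
        for frac in feature_layers
    ]


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm, guarding against division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def concat_features(per_layer_vectors):
    """Normalize each layer's vector independently, then concatenate."""
    features = []
    for vec in per_layer_vectors:
        features.extend(l2_normalize(vec))
    return features


indices = layer_indices([0.25, 0.5, 0.75], num_hidden_layers=16)
features = concat_features([[3.0, 4.0], [0.0, 2.0]])
```

Normalizing per layer keeps any one layer with large activations from dominating the concatenated embedding.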
### Performance
- **`torch_compile: true`**: Recommended for strided mode. Provides additional speedup via graph compilation.
- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.
### Memory
- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.
## Example Configs
| Config | Mode | Model | Description |
|--------|------|-------|-------------|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |
## Citation
```bibtex
@article{jelassi2026matching,
title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
journal={arXiv preprint arXiv:2603.12248},
year={2026}
}
```


@@ -0,0 +1,28 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT.
Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer.
"""
def transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "user", "content": example["input"]},
            ],
            "ground_truth": example["output"],
        }

    return transform_fn, {
        "remove_columns": [
            "id",
            "domain",
            "generation_algorithm",
            "llm_judgement",
            "unit_tests",
            "tests_execution_status",
            "average_test_score",
        ]
    }


@@ -0,0 +1,31 @@
"""
Dataset transform for unstructured text data with strided EBFT.
Tokenizes raw text into fixed-length input_ids for the strided trainer.
Sequences are padded to sequence_len for uniform batching.
"""
def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        text = example.get("question", example.get("text", ""))
        if tokenizer is None:
            return {"prompt": text}
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            add_special_tokens=True,
            return_tensors=None,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": list(encoded["input_ids"]),
        }

    return transform_fn, {"remove_columns": ["question", "answer"]}


@@ -0,0 +1,80 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.
Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""
def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        # Extract prompt and completion from the example
        prompt_text = example.get(
            "input", example.get("prompt", example.get("question", ""))
        )
        completion_text = example.get(
            "output", example.get("completion", example.get("answer", ""))
        )
        if tokenizer is None:
            return {"prompt": prompt_text}
        # Fall back to EOS only when no pad token is defined; a plain `or`
        # would also discard a valid pad_token_id of 0.
        pad_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        # Tokenize prompt and completion separately
        prompt_enc = tokenizer(
            prompt_text,
            truncation=False,
            add_special_tokens=True,
            return_tensors=None,
        )
        completion_enc = tokenizer(
            completion_text,
            truncation=False,
            add_special_tokens=False,
            return_tensors=None,
        )
        prompt_ids = prompt_enc["input_ids"]
        completion_ids = completion_enc["input_ids"]
        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
        total_len = len(prompt_ids) + len(completion_ids)
        if total_len > seq_len:
            # Truncate completion first, then prompt if needed
            max_completion = seq_len - len(prompt_ids)
            if max_completion < 1:
                # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
                prompt_ids = prompt_ids[: seq_len - 1]
                completion_ids = completion_ids[:1]
            else:
                completion_ids = completion_ids[:max_completion]
        input_ids = prompt_ids + completion_ids
        prompt_length = len(prompt_ids)
        # Labels: -100 for prompt tokens, input_ids for completion tokens
        labels = [-100] * prompt_length + completion_ids
        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    # Signal to remove all original columns (filtered to existing ones at map time)
    return transform_fn, {
        "remove_columns": "__all__",
    }


@@ -0,0 +1,64 @@
# EBFT validation config — no vLLM, uses HF generate for simplicity
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml
base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
rl: ebft
ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0
trl:
  num_generations: 4
  max_completion_length: 128
  temperature: 1.0
  use_vllm: false
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 10
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-validation
wandb_project: ebft
wandb_run_id:
wandb_watch:
wandb_log_model:
logging_steps: 1
save_steps: 100


@@ -0,0 +1,81 @@
# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
#
# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
# https://arxiv.org/abs/2603.12248
#
# Prerequisites:
# 1. Start vLLM server on a separate GPU:
# CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
# --model meta-llama/Llama-3.2-1B \
# --host 0.0.0.0 --port 8000 \
# --gpu-memory-utilization 0.4 --dtype bfloat16
#
# 2. Run training:
# CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
# --- Training method ---
rl: ebft
# --- EBFT configuration ---
ebft:
  feature_layers: [0.25, 0.5, 0.75] # extract hidden states at 25%, 50%, 75% depth
  embed_method: last_token # pool to sequence embedding via last token
  use_whitening: false # SVD whitening (disable for speed in small runs)
  alignment_coef: 1.0 # cosine similarity with ground-truth features
  diversity_coef: 1.0 # pairwise similarity penalty
  ce_coef: 0.0 # cross-entropy on ground-truth (0 = pure feature matching)
# --- Generation settings (via TRL/GRPO infrastructure) ---
trl:
  num_generations: 4 # samples per prompt for RLOO
  max_completion_length: 256 # max generated tokens
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
# --- Dataset ---
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%] # first 1% for validation runs
# --- Training hyperparameters ---
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 50
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
weight_decay: 0.01
# --- LoRA (recommended to reduce memory with frozen feature network) ---
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
# --- Hardware ---
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama-1b-opencode
# --- Logging ---
use_tensorboard: true
logging_steps: 1
save_steps: 25


@@ -0,0 +1,65 @@
# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
  mode: strided # strided block-parallel generation
  stride: 8 # tokens between anchor points
  context_length: 8 # context window per block
  generate_max_len: 8 # tokens to generate per block
  n_samples_per_prompt: 4 # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03 # small CE weight for structured data
  advantage_estimator: rloo
  min_completion_prefix: 8 # skip anchors too close to prompt boundary
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
# (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true # required for flex_attention (non-reentrant causes CheckpointError)
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured
wandb_project: ebft
logging_steps: 1
save_steps: 100


@@ -0,0 +1,60 @@
# EBFT Strided Mode: For unstructured text data (raw code, prose)
# Uses strided block-parallel generation — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
  mode: strided # strided block-parallel generation
  stride: 8 # tokens between anchor points
  context_length: 8 # context window per block
  generate_max_len: 8 # tokens to generate per block
  n_samples_per_prompt: 4 # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:100]
sequence_len: 256
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 5
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-strided-validation
wandb_project: ebft
logging_steps: 1
save_steps: 100


@@ -0,0 +1,69 @@
# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml
base_model: meta-llama/Llama-3.2-3B
rl: ebft
ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0 # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate
  advantage_estimator: rloo
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01
adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: true
ddp: false
device_map:
  "": 0
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided
wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50


@@ -0,0 +1,58 @@
# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# Feature network is CPU-offloaded to fit in single 32GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml
base_model: meta-llama/Llama-3.1-8B
rl: ebft
ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01
bf16: auto
flash_attention: false # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
special_tokens:
  pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided
wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50


@@ -0,0 +1,86 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml
base_model: Qwen/Qwen3.5-4B
rl: ebft
ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0
trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  async_prefetch: true
  vllm_server_timeout: 300
vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured-async
wandb_project: ebft
logging_steps: 1
save_steps: 50


@@ -0,0 +1,77 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
# --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml
base_model: Qwen/Qwen3.5-4B
rl: ebft
ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0
trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false # disable Qwen3.5 thinking mode for shorter completions
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured
wandb_project: ebft
logging_steps: 1
save_steps: 50


@@ -0,0 +1,82 @@
# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml
base_model: Qwen/Qwen3.5-9B
rl: ebft
ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0
trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  vllm_server_timeout: 300
vllm:
  gpu_memory_utilization: 0.7
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 3.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-9b-structured
wandb_project: ebft
logging_steps: 1
save_steps: 50


@@ -38,18 +38,14 @@ def do_vllm_serve(
     cfg = load_cfg(config)
     model = cfg.base_model
-    # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
+    # Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
+    # We default to axolotl's serve module instead of TRL's because TRL's sends
+    # truncate_prompt_tokens which is unsupported in vLLM 0.17+.
     serve_module = cli_args.get("serve_module") or getattr(
         cfg.vllm, "serve_module", None
     )
-    if (
-        serve_module is None
-        and getattr(cfg, "trl", None)
-        and getattr(cfg.trl, "vllm_lora_sync", False)
-    ):
-        serve_module = "axolotl.scripts.vllm_serve_lora"
     if serve_module is None:
-        serve_module = "trl.scripts.vllm_serve"
+        serve_module = "axolotl.scripts.vllm_serve_lora"
     vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
     tensor_parallel_size = 1
     data_parallel_size = 1
@@ -79,6 +75,12 @@ def do_vllm_serve(
         cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
     )
+    cli_enforce_eager = cli_args.get("enforce_eager")
+    cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
+    raw_enforce_eager = (
+        cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
+    )
+    enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
     base_kwargs = dict(
         model=model,
         tensor_parallel_size=tensor_parallel_size,
@@ -89,6 +91,7 @@ def do_vllm_serve(
         dtype=dtype,
         max_model_len=max_model_len,
         enable_prefix_caching=enable_prefix_caching,
+        enforce_eager=enforce_eager,
     )
     # Use LoRAScriptArguments when serving with native LoRA support
@@ -98,6 +101,10 @@ def do_vllm_serve(
         lora_kwargs = {}
         if hasattr(cfg, "lora_r") and cfg.lora_r:
             lora_kwargs["max_lora_rank"] = cfg.lora_r
+        # Disable native LoRA in vLLM if not using vllm_lora_sync
+        # (merged weight sync via batch_update doesn't need vLLM LoRA mode)
+        if not getattr(cfg.trl, "vllm_lora_sync", False):
+            lora_kwargs["enable_lora"] = False
         vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
     else:
         vllm_script_args = AxolotlScriptArguments(
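The `enforce_eager` handling in this file gives the CLI flag precedence over the config value and falls back to `False` when neither is set. A minimal standalone sketch of that precedence rule follows; the helper name `resolve_enforce_eager` is hypothetical and not part of axolotl.

```python
from typing import Optional


def resolve_enforce_eager(
    cli_value: Optional[bool], cfg_value: Optional[bool]
) -> bool:
    """Hypothetical helper mirroring the precedence used in the diff.

    The CLI value wins whenever it is not None (even if it is False),
    otherwise the config value is used, and the default is False.
    """
    raw = cfg_value if cli_value is None else cli_value
    return bool(raw) if raw is not None else False
```

Checking `is None` rather than using `cli_value or cfg_value` matters here: an explicit `False` on the command line would otherwise be overridden by a `True` in the config.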


@@ -118,7 +118,7 @@ def load_preference_datasets(
     train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
     total_num_steps: int | None = None
-    if cfg.rl is not RLType.GRPO:
+    if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
         total_num_steps = int(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
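The hunk above skips the step-count precomputation for both GRPO and EBFT and keeps the `ceil(len * epochs / batch_size)` formula for every other RL type. A small sketch of that rule, assuming plain strings in place of axolotl's `RLType` enum members; `estimate_total_steps` is a hypothetical name used only for illustration.

```python
import math
from typing import Optional

# RL types for which the step count is left unset, per the hunk above.
# Plain strings stand in for RLType members in this sketch.
SKIP_STEP_COUNT = {"grpo", "ebft"}


def estimate_total_steps(
    rl: str, num_examples: int, num_epochs: int, batch_size: int
) -> Optional[int]:
    """Return None for generation-based RL types, else the ceil'd step count."""
    if rl in SKIP_STEP_COUNT:
        return None
    return int(math.ceil(num_examples * num_epochs / batch_size))
```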


@@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_cls = AxolotlKTOTrainer
         elif self.cfg.rl is RLType.SIMPO:
             trainer_cls = AxolotlCPOTrainer
+        elif self.cfg.rl is RLType.EBFT:
+            from axolotl.core.trainers.ebft import EBFTStrategy
+            trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
+            trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig
             training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
+        elif self.cfg.rl is RLType.EBFT:
+            from axolotl.core.trainers.ebft import EBFTStrategy
+            training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
+            training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
+            blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if (
             self.cfg.adapter
             and self.peft_config
-            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
+            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
         ):
             trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:


@@ -4,6 +4,8 @@
 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
+from .ebft.strided import AxolotlStridedEBFTTrainer
+from .ebft.trainer import AxolotlEBFTTrainer
 from .mamba import AxolotlMambaTrainer
 from .trl import (
     AxolotlCPOTrainer,


@@ -0,0 +1,213 @@
"""EBFT (Energy-Based Fine-Tuning) Strategy for training
Two modes:
- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
"""
from typing import Any
from axolotl.core.trainers.ebft.args import (
AxolotlAsyncEBFTConfig,
AxolotlEBFTConfig,
AxolotlStridedEBFTConfig,
)
from axolotl.utils.dict import DictDefault
def _get_ebft_mode(cfg: DictDefault) -> str:
"""Determine EBFT mode from config."""
if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
return cfg.ebft.mode
return "structured"
class EBFTStrategy:
"""Strategy for EBFT training — dispatches between structured and strided modes."""
@classmethod
def get_trainer_class(cls, cfg: DictDefault | None = None):
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer
return AxolotlStridedEBFTTrainer
# Structured mode: async or sync
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
if use_async:
from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
return AxolotlAsyncEBFTTrainer
from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer
return AxolotlEBFTTrainer
@classmethod
def get_training_args_class(cls, cfg: DictDefault | None = None):
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
return AxolotlStridedEBFTConfig
# Structured mode: async or sync config
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
if use_async:
return AxolotlAsyncEBFTConfig
return AxolotlEBFTConfig
@classmethod
def is_strided(cls, cfg: DictDefault) -> bool:
return _get_ebft_mode(cfg) == "strided"
@classmethod
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
"""Map axolotl YAML config fields to training args kwargs."""
kwargs: dict[str, Any] = {}
mode = _get_ebft_mode(cfg)
# Common EBFT fields
ebft = cfg.ebft
if ebft:
if ebft.feature_layers is not None:
kwargs["ebft_feature_layers"] = ebft.feature_layers
if ebft.embed_method is not None:
kwargs["ebft_embed_method"] = ebft.embed_method
if ebft.use_whitening is not None:
kwargs["ebft_use_whitening"] = ebft.use_whitening
if ebft.alignment_coef is not None:
kwargs["ebft_alignment_coef"] = ebft.alignment_coef
if ebft.diversity_coef is not None:
kwargs["ebft_diversity_coef"] = ebft.diversity_coef
if ebft.ce_coef is not None:
kwargs["ebft_ce_coef"] = ebft.ce_coef
if getattr(ebft, "adaptive_max_tokens", None) is not None:
kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
if getattr(ebft, "gt_length_multiplier", None) is not None:
kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier
if mode == "strided":
# Strided-specific fields
if ebft:
if ebft.stride is not None:
kwargs["ebft_stride"] = ebft.stride
if ebft.context_length is not None:
kwargs["ebft_context_length"] = ebft.context_length
if ebft.generate_max_len is not None:
kwargs["ebft_generate_max_len"] = ebft.generate_max_len
if ebft.n_samples_per_prompt is not None:
kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
if ebft.temperature is not None:
kwargs["ebft_temperature"] = ebft.temperature
if ebft.top_p is not None:
kwargs["ebft_top_p"] = ebft.top_p
if ebft.rl_coef is not None:
kwargs["ebft_rl_coef"] = ebft.rl_coef
if ebft.advantage_estimator is not None:
kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
if ebft.min_completion_prefix is not None:
kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
else:
# Structured mode: map TRL config fields
trl = cfg.trl
if trl:
if trl.use_vllm:
kwargs["use_vllm"] = trl.use_vllm
if trl.vllm_mode:
kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
vllm_cfg = cfg.vllm
if vllm_cfg:
kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)
kwargs["vllm_tensor_parallel_size"] = (
vllm_cfg.tensor_parallel_size
)
kwargs["vllm_server_host"] = trl.vllm_server_host or (
trl.vllm.host if trl.vllm else None
)
kwargs["vllm_server_port"] = trl.vllm_server_port or (
trl.vllm.port if trl.vllm else None
)
if trl.vllm_server_timeout:
kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.num_generations:
kwargs["num_generations"] = trl.num_generations
if trl.max_completion_length is not None:
kwargs["max_completion_length"] = trl.max_completion_length
if trl.temperature is not None:
kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
kwargs["min_p"] = trl.min_p
if trl.num_iterations is not None:
kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
kwargs["epsilon_high"] = trl.epsilon_high
if trl.scale_rewards is not None:
kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.log_completions is not None:
kwargs["log_completions"] = trl.log_completions
if trl.num_completions_to_print is not None:
kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.sync_ref_model:
kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.repetition_penalty is not None:
kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.generation_kwargs is not None:
kwargs["generation_kwargs"] = trl.generation_kwargs
if trl.chat_template_kwargs is not None:
kwargs["chat_template_kwargs"] = trl.chat_template_kwargs
# Async prefetch fields (only pass when enabled — sync config doesn't have these)
if getattr(trl, "async_prefetch", False):
kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "vllm_sync_interval", None) is not None:
kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "vllm_lora_sync", False):
kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return kwargs
@classmethod
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
return []
@classmethod
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
return {}
@classmethod
def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
return [
"dataset_num_proc",
"max_length",
"max_prompt_length",
"include_tokens_per_second",
"beta",
]
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod
def get_collator(cls, *args, **kwargs):
return None
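The dispatch in `EBFTStrategy.get_trainer_class` can be summarised as: strided mode wins first, then `trl.async_prefetch` selects the async trainer, and the synchronous structured trainer is the fallback. The sketch below reproduces only that decision order, using `SimpleNamespace` objects as stand-ins for axolotl's `DictDefault` config and returning class names as strings; `get_mode` and `select_trainer_name` are hypothetical helpers.

```python
from types import SimpleNamespace


def get_mode(cfg) -> str:
    """Return cfg.ebft.mode when set, otherwise default to 'structured'."""
    ebft = getattr(cfg, "ebft", None)
    mode = getattr(ebft, "mode", None) if ebft else None
    return mode or "structured"


def select_trainer_name(cfg) -> str:
    """Mirror the decision order of the strategy class (names only)."""
    if get_mode(cfg) == "strided":
        return "AxolotlStridedEBFTTrainer"
    trl = getattr(cfg, "trl", None)
    if trl and getattr(trl, "async_prefetch", False):
        return "AxolotlAsyncEBFTTrainer"
    return "AxolotlEBFTTrainer"


# Example: strided mode takes precedence even if async_prefetch is set.
example_cfg = SimpleNamespace(
    ebft=SimpleNamespace(mode="strided"),
    trl=SimpleNamespace(async_prefetch=True),
)
selected = select_trainer_name(example_cfg)
```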


@@ -0,0 +1,133 @@
"""
EBFT-specific training arguments.
Three config classes:
- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
- AxolotlAsyncEBFTConfig: extends FastAsyncGRPOConfig for async structured QA data
- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
"""
from dataclasses import dataclass, field
from transformers import TrainingArguments
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
# -- Shared EBFT fields as a mixin --
@dataclass
class EBFTFieldsMixin:
"""Common fields shared between structured and strided EBFT configs."""
ebft_feature_layers: list[float] = field(
default_factory=lambda: [0.25, 0.5, 0.75],
metadata={"help": "Fractional layer depths for feature extraction"},
)
ebft_embed_method: str = field(
default="last_token",
metadata={"help": "Pooling method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"},
)
ebft_use_whitening: bool = field(
default=False,
metadata={"help": "Apply SVD whitening to feature embeddings"},
)
ebft_alignment_coef: float = field(
default=1.0,
metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
)
ebft_diversity_coef: float = field(
default=1.0,
metadata={"help": "Coefficient for diversity penalty"},
)
ebft_ce_coef: float = field(
default=0.0,
metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
)
ebft_adaptive_max_tokens: bool = field(
default=True,
metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
)
ebft_gt_length_multiplier: float = field(
default=1.5,
metadata={
"help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
},
)
# -- Structured mode: extends GRPOTrainer for QA data with vLLM --
@dataclass
class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
"""EBFT config for structured QA data — extends GRPOConfig."""
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
},
)
# -- Async structured mode: extends FastAsyncGRPOConfig --
@dataclass
class AxolotlAsyncEBFTConfig(
EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
):
"""EBFT config for async structured QA data — extends FastAsyncGRPOConfig.
Includes all async fields: async_prefetch, vllm_lora_sync,
skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
"""
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
},
)
# -- Strided mode: extends TrainingArguments for unstructured text --
@dataclass
class AxolotlStridedEBFTConfig(
EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
):
"""EBFT config for unstructured text with strided block-parallel generation."""
ebft_stride: int = field(
default=8,
metadata={"help": "Stride between anchor points (in tokens)"},
)
ebft_context_length: int = field(
default=8,
metadata={"help": "Context window size for each block"},
)
ebft_generate_max_len: int = field(
default=8,
metadata={"help": "Number of tokens to generate per block"},
)
ebft_n_samples_per_prompt: int = field(
default=4,
metadata={"help": "Number of independent rollouts per document"},
)
ebft_temperature: float = field(
default=0.6,
metadata={"help": "Sampling temperature for strided generation"},
)
ebft_top_p: float = field(
default=1.0,
metadata={"help": "Top-p nucleus sampling threshold"},
)
ebft_rl_coef: float = field(
default=1.0,
metadata={"help": "RL policy gradient loss coefficient"},
)
ebft_advantage_estimator: str = field(
default="rloo",
metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
)
ebft_min_completion_prefix: int = field(
default=0,
metadata={"help": "Minimum tokens into completion before placing anchors"},
)
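`ebft_feature_layers` is expressed as fractional depths, while `extract_hidden_states` in this commit takes integer `layer_indices` (its docstring gives `[8, 16, 24]` for a 32-layer model). The conversion code is not visible in this part of the diff, so the following is only a plausible sketch under the assumption that each fraction is multiplied by the layer count and rounded; `fractions_to_layer_indices` is a hypothetical name.

```python
def fractions_to_layer_indices(fractions: list[float], num_layers: int) -> list[int]:
    """Map fractional depths to hidden_states indices (assumed rounding rule).

    Index 0 is the embedding output and index num_layers is the last layer,
    so results are clamped to [0, num_layers]. Duplicates are dropped.
    """
    indices: list[int] = []
    for frac in fractions:
        idx = int(round(frac * num_layers))
        idx = max(0, min(num_layers, idx))
        if idx not in indices:
            indices.append(idx)
    return indices
```

With the default `[0.25, 0.5, 0.75]` and 32 layers this yields the `[8, 16, 24]` example from the docstring.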


@@ -0,0 +1,308 @@
"""
Fused Triton kernels for strided EBFT.
These kernels eliminate intermediate tensor materializations that dominate
the elementwise/fill category (~40% of CUDA time in profiling).
Kernels:
1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
2. fused_reinforce_loss: -logp * advantage * mask, reduced to a scalar masked mean
3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
4. fused_diversity_penalty: mean pairwise dot product to the other samples in a group
"""
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# 1. Fused log_softmax + gather (selective log softmax)
# ---------------------------------------------------------------------------
# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
# This avoids materializing the full (B, S, V) log_softmax output.
@triton.jit
def _fused_log_softmax_gather_kernel(
logits_ptr, # (B*S, V) row-major
labels_ptr, # (B*S,) int64
output_ptr, # (B*S,) float32
V: tl.constexpr, # vocab size
BLOCK_V: tl.constexpr, # tile width over vocab
):
"""Compute log_softmax(logits)[label] for each row without materializing full output."""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + row * V
label = tl.load(labels_ptr + row)
# Pass 1: find max for numerical stability
max_val = -float("inf")
for v_start in range(0, V, BLOCK_V):
v_offsets = v_start + tl.arange(0, BLOCK_V)
mask = v_offsets < V
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
max_val = tl.maximum(max_val, tl.max(vals, axis=0))
# Pass 2: compute sum(exp(x - max))
sum_exp = 0.0
for v_start in range(0, V, BLOCK_V):
v_offsets = v_start + tl.arange(0, BLOCK_V)
mask = v_offsets < V
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)
log_sum_exp = tl.log(sum_exp)
# Gather: log_softmax[label] = logits[label] - max - log_sum_exp
target_logit = tl.load(logits_row_ptr + label)
result = target_logit - max_val - log_sum_exp
tl.store(output_ptr + row, result)
def fused_log_softmax_gather(
logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
"""Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
Args:
logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
labels: (B, S) or (B*S,) int64 tensor of target indices
Returns:
(B, S) or (B*S,) float32 tensor of selected log probabilities
"""
orig_shape = logits.shape[:-1]
V = logits.shape[-1]
logits_2d = logits.reshape(-1, V).contiguous()
labels_1d = labels.reshape(-1).contiguous()
n_rows = logits_2d.shape[0]
output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
# Choose BLOCK_V: must be power of 2, large enough for good occupancy
BLOCK_V = min(triton.next_power_of_2(V), 65536)
_fused_log_softmax_gather_kernel[(n_rows,)](
logits_2d,
labels_1d,
output,
V=V,
BLOCK_V=BLOCK_V,
)
return output.view(orig_shape)
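For readers who want to check the two-pass scheme without a GPU, here is a plain-Python reference for a single row. It follows the same steps as the kernel (tiled max, tiled sum of shifted exponentials, then `logit[label] - max - log_sum_exp`), but it is only an illustrative sketch: `log_softmax_gather_row` and its `block` parameter are hypothetical and not part of the commit.

```python
import math


def log_softmax_gather_row(logits: list[float], label: int, block: int = 4) -> float:
    """Return log_softmax(logits)[label] using two tiled passes over the row."""
    # Pass 1: running maximum over tiles, for numerical stability.
    max_val = -math.inf
    for start in range(0, len(logits), block):
        max_val = max(max_val, max(logits[start : start + block]))

    # Pass 2: accumulate sum(exp(x - max)) tile by tile.
    sum_exp = 0.0
    for start in range(0, len(logits), block):
        sum_exp += sum(math.exp(x - max_val) for x in logits[start : start + block])

    # Gather: only the target logit is needed, never the full log-softmax row.
    return logits[label] - max_val - math.log(sum_exp)
```

Subtracting the row maximum before exponentiating is what keeps large logits (for example values near 1000) from overflowing.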
# ---------------------------------------------------------------------------
# 2. Fused masked REINFORCE loss reduction
# ---------------------------------------------------------------------------
# Instead of: (-logp * adv * mask).sum() / mask.sum()
# We do the masked multiply-accumulate in one kernel, returning (sum, count).
@triton.jit
def _fused_reinforce_loss_kernel(
logps_ptr, # (N,) float32 per-token log probs
advs_ptr, # (N,) float32 advantages
mask_ptr, # (N,) bool action mask
partial_sum_ptr, # (n_blocks,) partial sums
partial_cnt_ptr, # (n_blocks,) partial counts
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
block_id = tl.program_id(0)
offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
valid = offsets < N
logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)
# -logp * advantage * mask
loss = -logps * advs * m
block_sum = tl.sum(loss, axis=0)
block_cnt = tl.sum(m, axis=0)
tl.store(partial_sum_ptr + block_id, block_sum)
tl.store(partial_cnt_ptr + block_id, block_cnt)
def fused_reinforce_loss(
per_token_logps: torch.Tensor,
advantages: torch.Tensor,
action_mask: torch.Tensor,
) -> torch.Tensor:
"""Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
All inputs should be flat or will be flattened. Returns scalar loss tensor.
"""
logps_flat = per_token_logps.reshape(-1).contiguous()
advs_flat = advantages.reshape(-1).contiguous()
mask_flat = action_mask.reshape(-1).contiguous()
N = logps_flat.shape[0]
BLOCK_N = 1024
n_blocks = triton.cdiv(N, BLOCK_N)
partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
_fused_reinforce_loss_kernel[(n_blocks,)](
logps_flat,
advs_flat,
mask_flat,
partial_sum,
partial_cnt,
N=N,
BLOCK_N=BLOCK_N,
)
total_sum = partial_sum.sum()
total_cnt = partial_cnt.sum().clamp(min=1)
return total_sum / total_cnt
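The reduction above works in two stages: each block writes a partial sum and a partial count, and the host then divides the total sum by the total count (clamped to at least 1). A plain-Python sketch of the same arithmetic, with the hypothetical name `masked_reinforce_loss`, can serve as a reference when validating the kernel:

```python
def masked_reinforce_loss(
    logps: list[float], advs: list[float], mask: list[int], block: int = 4
) -> float:
    """Masked mean of -logp * advantage, accumulated block by block."""
    partial_sums: list[float] = []
    partial_cnts: list[float] = []
    for start in range(0, len(logps), block):
        block_sum = 0.0
        block_cnt = 0.0
        for offset in range(start, min(start + block, len(logps))):
            weight = 1.0 if mask[offset] else 0.0
            block_sum += -logps[offset] * advs[offset] * weight
            block_cnt += weight
        partial_sums.append(block_sum)
        partial_cnts.append(block_cnt)
    # Final reduction; the count is clamped so an all-masked batch gives 0.
    return sum(partial_sums) / max(sum(partial_cnts), 1.0)
```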
# ---------------------------------------------------------------------------
# 3. Fused cosine similarity (batched, for EBFT rewards)
# ---------------------------------------------------------------------------
# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
# we fuse the dot product, norm computation, and division into one kernel.
@triton.jit
def _fused_cosine_sim_kernel(
a_ptr, # (N, D) first set of vectors
b_ptr, # (N, D) second set of vectors
out_ptr, # (N,) cosine similarities
D: tl.constexpr,
BLOCK_D: tl.constexpr,
):
row = tl.program_id(0)
a_row_ptr = a_ptr + row * D
b_row_ptr = b_ptr + row * D
dot = 0.0
norm_a = 0.0
norm_b = 0.0
for d_start in range(0, D, BLOCK_D):
d_offsets = d_start + tl.arange(0, BLOCK_D)
mask = d_offsets < D
a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
dot += tl.sum(a_vals * b_vals, axis=0)
norm_a += tl.sum(a_vals * a_vals, axis=0)
norm_b += tl.sum(b_vals * b_vals, axis=0)
denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
denom = tl.where(denom > 1e-8, denom, 1e-8)
result = dot / denom
tl.store(out_ptr + row, result)
def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Compute cosine similarity along the last dimension.
Args:
a, b: (..., D) tensors of the same shape
Returns:
(...,) tensor of cosine similarities
"""
orig_shape = a.shape[:-1]
D = a.shape[-1]
a_2d = a.reshape(-1, D).contiguous()
b_2d = b.reshape(-1, D).contiguous()
N = a_2d.shape[0]
output = torch.empty(N, device=a.device, dtype=torch.float32)
BLOCK_D = min(triton.next_power_of_2(D), 4096)
_fused_cosine_sim_kernel[(N,)](
a_2d,
b_2d,
output,
D=D,
BLOCK_D=BLOCK_D,
)
return output.view(orig_shape)
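As a reference for the per-row computation, the kernel accumulates the dot product and both squared norms in one loop, then floors the denominator at `1e-8`. The plain-Python function below (hypothetical name `cosine_similarity_row`) follows the same steps for one pair of vectors and is meant only as a sketch for checking results.

```python
import math


def cosine_similarity_row(a: list[float], b: list[float], eps: float = 1e-8) -> float:
    """Cosine similarity with the denominator floored at eps, as in the kernel."""
    dot = 0.0
    norm_a = 0.0
    norm_b = 0.0
    for x, y in zip(a, b):
        dot += x * y
        norm_a += x * x
        norm_b += y * y
    denom = math.sqrt(norm_a) * math.sqrt(norm_b)
    denom = denom if denom > eps else eps
    return dot / denom
```

The floor means an all-zero vector yields a similarity of 0 instead of a division by zero.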
# ---------------------------------------------------------------------------
# 4. Fused pairwise diversity penalty
# ---------------------------------------------------------------------------
# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1)
# We compute the pairwise dot products and exclusion in one kernel.
@triton.jit
def _fused_diversity_kernel(
emb_ptr, # (B, N, D) embeddings, row-major
out_ptr, # (B, N) diversity penalties
N: tl.constexpr, # n_samples
D: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""For each (b, i), compute mean dot product to all j != i."""
b = tl.program_id(0)
i = tl.program_id(1)
# Pointer to emb[b, i, :]
emb_bi_ptr = emb_ptr + (b * N + i) * D
total_sim = 0.0
for j in range(N):
emb_bj_ptr = emb_ptr + (b * N + j) * D
dot = 0.0
for d_start in range(0, D, BLOCK_D):
d_offsets = d_start + tl.arange(0, BLOCK_D)
d_mask = d_offsets < D
a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
tl.float32
)
b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
tl.float32
)
dot += tl.sum(a_vals * b_vals, axis=0)
# Exclude self-similarity (j == i)
is_other = j != i
total_sim += dot * is_other
result = total_sim / (N - 1)
tl.store(out_ptr + b * N + i, result)
def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
"""Compute mean pairwise dot product (excluding self) per sample.
Args:
embeddings: (B, N, D) tensor where N is n_samples
Returns:
(B, N) tensor of diversity penalties
"""
B, N, D = embeddings.shape
embeddings = embeddings.contiguous()
output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
if N <= 1:
return output # diversity is 0 when there's only one sample
BLOCK_D = min(triton.next_power_of_2(D), 4096)
_fused_diversity_kernel[(B, N)](
embeddings,
output,
N=N,
D=D,
BLOCK_D=BLOCK_D,
)
return output
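The diversity penalty for sample `i` is the mean dot product between its embedding and every other embedding in the same group. A plain-Python sketch for one group, using the hypothetical name `diversity_penalty_group`, makes the exclusion of the `j == i` term explicit:

```python
def diversity_penalty_group(group: list[list[float]]) -> list[float]:
    """For each embedding, mean dot product with all other embeddings in the group."""
    n = len(group)
    if n <= 1:
        # A single sample has nothing to be compared against.
        return [0.0] * n
    penalties: list[float] = []
    for i, emb_i in enumerate(group):
        total = 0.0
        for j, emb_j in enumerate(group):
            if j != i:  # exclude self-similarity
                total += sum(x * y for x, y in zip(emb_i, emb_j))
        penalties.append(total / (n - 1))
    return penalties
```

Note that these are raw dot products, not cosine similarities, so the values scale with embedding norm unless the embeddings are normalised beforehand.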


@@ -0,0 +1,264 @@
"""
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).
Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@torch.no_grad()
def extract_hidden_states(
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
layer_indices: list[int],
batch_size: int | None = None,
) -> torch.Tensor:
"""
Forward pass through model, extracting and concatenating hidden states
at specified layer indices.
Args:
model: The frozen feature network
input_ids: (B, S) token ids
attention_mask: (B, S) attention mask
layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model)
batch_size: If set, process in chunks to reduce peak memory
Returns:
Concatenated hidden states: (B, S, len(layer_indices) * H)
"""
if batch_size is None:
batch_size = input_ids.shape[0]
# Use the inner transformer body (skips lm_head) when available.
# This avoids the expensive hidden_dim × vocab_size matmul whose
# output (logits) is never used — only hidden_states are needed.
body = getattr(model, "model", None)
if body is not None and hasattr(body, "forward"):
forward_model = body
else:
forward_model = model
all_features = []
for i in range(0, input_ids.shape[0], batch_size):
chunk_ids = input_ids[i : i + batch_size]
chunk_mask = attention_mask[i : i + batch_size]
outputs = forward_model(
chunk_ids,
attention_mask=chunk_mask,
output_hidden_states=True,
return_dict=True,
)
# hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H)
# index 0 is the embedding layer output
hidden_states = outputs.hidden_states
chunk_features = []
for idx in layer_indices:
chunk_features.append(hidden_states[idx])
# Concatenate across feature dimension: (B, S, num_layers * H)
all_features.append(torch.cat(chunk_features, dim=-1))
return torch.cat(all_features, dim=0)
def apply_embed_method(
hidden_states: torch.Tensor,
method: str,
attention_mask: torch.Tensor | None = None,
prompt_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Pool per-token hidden states into per-sequence embeddings.
Args:
hidden_states: (B, S, D) concatenated hidden states
method: One of "last_token", "mean_pooling", "completion_mean", "concat"
attention_mask: (B, S) mask for mean pooling
prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)
Returns:
Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
(B, 3*D) for concat
"""
if method == "last_token":
if attention_mask is not None:
# Find last non-padding position per sample
last_idx = attention_mask.sum(dim=1).long() - 1 # (B,)
return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
return hidden_states[:, -1, :]
if method == "mean_pooling":
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
return hidden_states.mean(dim=1)
if method == "completion_mean":
# Mean pool over completion tokens only (exclude prompt)
if prompt_lengths is None:
raise ValueError("completion_mean requires prompt_lengths")
B, S, _ = hidden_states.shape
positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S)
comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S)
if attention_mask is not None:
comp_mask = comp_mask & attention_mask.bool()
mask = comp_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
if method == "concat":
B, S, D = hidden_states.shape
if attention_mask is not None:
valid_lens = attention_mask.sum(dim=1).long() # (B,)
else:
valid_lens = torch.full(
(B,), S, device=hidden_states.device, dtype=torch.long
)
# Compute quartile positions relative to each sample's valid length
# (assumes right-padding, so valid tokens occupy positions [0, valid_len))
q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
batch_idx = torch.arange(B, device=hidden_states.device)
return torch.cat(
[
hidden_states[batch_idx, q1],
hidden_states[batch_idx, q2],
hidden_states[batch_idx, q3],
],
dim=-1,
)
raise ValueError(f"Unknown embed_method: {method}")
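Two of the pooling modes are easy to get wrong by one position, so a single-sequence sketch may help. `completion_mean_single` averages only positions at or after the prompt length that are also unmasked, and `quartile_positions` reproduces the index arithmetic used by `concat`. Both names are hypothetical, and plain lists stand in for tensors.

```python
from typing import Optional


def completion_mean_single(
    hidden: list[list[float]],
    prompt_len: int,
    attention_mask: Optional[list[int]] = None,
) -> list[float]:
    """Mean over completion positions (pos >= prompt_len) that are not padding."""
    seq_len = len(hidden)
    dim = len(hidden[0])
    keep = [
        pos >= prompt_len and (attention_mask is None or bool(attention_mask[pos]))
        for pos in range(seq_len)
    ]
    count = max(sum(keep), 1)  # same clamp(min=1) guard as the tensor version
    return [
        sum(hidden[pos][d] for pos in range(seq_len) if keep[pos]) / count
        for d in range(dim)
    ]


def quartile_positions(valid_len: int, seq_len: int) -> tuple[int, int, int]:
    """Positions at 1/4, 1/2 and 3/4 of the valid length, clamped to the sequence."""

    def clamp(index: int) -> int:
        return max(0, min(seq_len - 1, index))

    return clamp(valid_len // 4), clamp(valid_len // 2), clamp(3 * valid_len // 4)
```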
@torch.no_grad()
def get_alignment_rewards(
gen_embedding: torch.Tensor,
gt_embedding: torch.Tensor,
) -> torch.Tensor:
"""
Compute alignment reward as cosine similarity between generated
and ground-truth feature embeddings.
Args:
gen_embedding: (B, D) generated sequence embeddings
gt_embedding: (B, D) ground-truth sequence embeddings
If num_generations > 1, gt_embedding should be repeated
to match gen_embedding's batch dim.
Returns:
Alignment rewards: (B,) cosine similarities in [-1, 1]
"""
return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)
@torch.no_grad()
def get_diversity_rewards(
gen_embedding: torch.Tensor,
num_generations: int,
) -> torch.Tensor:
"""
Compute diversity penalty as mean pairwise dot-product similarity
between samples from the same prompt (excluding self-similarity).
Args:
gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
num_generations: Number of generations per prompt
Returns:
Diversity penalties: (B,) mean similarity to other samples from same prompt
"""
if num_generations <= 1:
return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)
num_prompts = gen_embedding.shape[0] // num_generations
# Reshape to (num_prompts, num_generations, D)
reshaped = gen_embedding.view(num_prompts, num_generations, -1)
# Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
sims = torch.bmm(reshaped, reshaped.transpose(1, 2))
# Zero out self-similarity (diagonal)
eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
# Mean similarity to other samples: (num_prompts, num_generations)
diversity = sims.sum(dim=-1) / (num_generations - 1)
# Flatten back to (B,)
return diversity.view(-1)
def whiten_embeddings_batched(
phi: torch.Tensor,
phi_gt: torch.Tensor,
whiten_tol: float = 1e-5,
normalize: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Whiten generated embeddings using SVD, then apply same transform to ground-truth.
Whitening decorrelates feature dimensions so no single direction dominates
the feature-matching loss. Uses pseudo-inverse for rank-deficient cases.
Note: Singular values scale with sqrt(B), so reward magnitudes are
batch-size dependent. This is acceptable because B = n_samples_per_prompt
which is fixed during training (typically 2-4).
Args:
phi: (B, D) generated embeddings (used to estimate covariance)
phi_gt: (B, D) ground-truth embeddings
whiten_tol: Tolerance for singular value cutoff
normalize: If True, L2-normalize after whitening
Returns:
Whitened (phi, phi_gt) tuple, each (B, D)
"""
phi_f = phi.float()
phi_gt_f = phi_gt.float()
# Feature-space SVD: operate on phi_f.T (D, B) so U is (D, min(D, B))
try:
U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
except torch._C._LinAlgError:
# Fallback: add small noise
noise = 1e-6 * phi_f.abs().mean()
try:
U, S, _ = torch.linalg.svd(
(phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
full_matrices=False,
)
except torch._C._LinAlgError:
if normalize:
return (
F.normalize(phi, p=2, dim=-1),
F.normalize(phi_gt, p=2, dim=-1),
)
return phi, phi_gt
U, S = U.squeeze(0), S.squeeze(0) # U: (D, min(D,B)), S: (min(D,B),)
# Safe inverse of singular values
s_max = S.max()
inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))
# W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
W = (U * inv_s.unsqueeze(0)) @ U.T # (D, D)
phi_w = (phi_f @ W).to(phi.dtype) # (B, D)
phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype) # (B, D)
if normalize:
phi_w = F.normalize(phi_w, p=2, dim=-1)
phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)
return phi_w, phi_gt_w

File diff suppressed because it is too large


@@ -0,0 +1,531 @@
"""
EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.
Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
the standard GRPO reward function interface.
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""
import contextlib
import copy
from typing import TYPE_CHECKING, Any
import torch
from datasets import Dataset, IterableDataset
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
from axolotl.core.trainers.ebft.rewards import (
apply_embed_method,
extract_hidden_states,
get_alignment_rewards,
get_diversity_rewards,
whiten_embeddings_batched,
)
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOTrainer,
)
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from collections import defaultdict
from accelerate import Accelerator
from trl.generation.vllm_generation import VLLMGeneration
LOG = get_logger(__name__)
class EBFTMixin:
"""
Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.
Provides:
- Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
- _feature_matching_reward() callable for GRPO reward function interface
- _sequential_rollout() for multi-turn conversations
"""
# Type stubs for attributes provided by the composed GRPOTrainer base class.
# These are not defined here but accessed via cooperative multiple inheritance.
if TYPE_CHECKING:
accelerator: Accelerator
model: PreTrainedModel
args: AxolotlEBFTConfig
processing_class: PreTrainedTokenizerBase
num_generations: int
vllm_generation: VLLMGeneration
_metrics: defaultdict
_tag_names = ["trl", "ebft", "axolotl"]
def __init__(
self,
model: str | PreTrainedModel,
args: AxolotlEBFTConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset
| IterableDataset
| dict[str, Dataset | IterableDataset]
| None = None,
processing_class: PreTrainedTokenizerBase | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: Any | None = None,
):
# Pass our feature-matching reward function to GRPOTrainer
# It will be called with (prompts, completions, **kwargs) where
# kwargs includes all extra dataset fields like "ground_truth"
super().__init__( # type: ignore[call-arg]
model=model,
reward_funcs=[self._feature_matching_reward],
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
assert args is not None
# --- Feature network setup ---
unwrapped = self.accelerator.unwrap_model(self.model)
# Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping
self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
unwrapped, "disable_adapter"
)
if self._share_feature_weights:
# Share weights: use actor's base model with adapters disabled.
# Saves a full model copy (~8 GB for 4B model).
self.feature_network = None
param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
LOG.info(
f"EBFT feature network shares actor weights (PEFT disable_adapter). "
f"Saving ~{param_gb:.1f} GB"
)
else:
LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
self.feature_network = copy.deepcopy(unwrapped)
for param in self.feature_network.parameters():
param.requires_grad = False
self.feature_network.eval()
# Compute layer indices from fractional depths
# Handle VLM models where num_hidden_layers is on text_config
config = unwrapped.config
if hasattr(config, "text_config") and hasattr(
config.text_config, "num_hidden_layers"
):
config = config.text_config
num_layers = config.num_hidden_layers
self.feature_layer_indices = [
int(frac * num_layers) for frac in args.ebft_feature_layers
]
LOG.info(
f"EBFT feature extraction from layers {self.feature_layer_indices} "
f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
)
if args.ebft_adaptive_max_tokens:
LOG.info(
f"EBFT adaptive max_tokens enabled "
f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
)
_adaptive_max_lock = None # initialized lazily
def _generate_only(self, inputs, rank0_only=False):
"""Override to set per-batch max_tokens based on ground-truth length.
Uses a lock to prevent race conditions in async mode, where concurrent
background threads could interleave mutations of max_completion_length.
"""
import threading
args = self.args
if (
args.ebft_adaptive_max_tokens
and hasattr(self, "vllm_generation")
and inputs
):
gt_texts = [
x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
]
if gt_texts:
gt_token_counts = [
len(self.processing_class.encode(gt, add_special_tokens=False))
for gt in gt_texts
]
multiplier = args.ebft_gt_length_multiplier
max_completion = self.vllm_generation.max_completion_length
adaptive_max = max(
min(int(c * multiplier), max_completion) for c in gt_token_counts
)
adaptive_max = max(adaptive_max, 64)
if self._adaptive_max_lock is None:
self._adaptive_max_lock = threading.Lock()
with self._adaptive_max_lock:
original = self.vllm_generation.max_completion_length
self.vllm_generation.max_completion_length = adaptive_max
try:
return super()._generate_only(inputs, rank0_only)
finally:
self.vllm_generation.max_completion_length = original
return super()._generate_only(inputs, rank0_only)
@torch.no_grad()
def _feature_matching_reward(
self,
prompts: list,
completions: list,
ground_truth: list[str] | None = None,
remaining_turns: list | None = None,
**kwargs,
) -> list[float]:
"""
Compute feature-matching rewards for generated completions.
This is called by GRPOTrainer's _generate_and_score_completions()
as a standard reward function. The `ground_truth` field comes from
the dataset via reward_kwargs.
For multi-turn conversations, `remaining_turns` contains the subsequent
user/assistant turn pairs. When present, we do sequential rollouts:
generate each assistant turn conditioned on history + previous generations,
then compute feature-matching rewards on the full generated conversation.
Args:
prompts: List of prompt strings/messages
completions: List of generated completion strings
ground_truth: List of reference completion strings (from dataset)
remaining_turns: List of remaining conversation turns after the
first assistant turn (for multi-turn rollouts)
Returns:
List of scalar rewards, one per completion
"""
if ground_truth is None:
LOG.warning("No ground_truth field in dataset — using zero rewards")
return [0.0] * len(prompts)
device = self.accelerator.device
args = self.args
num_gens = self.num_generations
# --- Multi-turn sequential rollout ---
# If remaining_turns is provided, generate subsequent assistant turns
# by calling vLLM for each turn, building up the full conversation.
if remaining_turns is not None and hasattr(self, "vllm_generation"):
completions = self._sequential_rollout(
prompts, completions, remaining_turns, num_gens
)
# --- Tokenize generated sequences: prompt + completion ---
gen_texts = []
gen_prompt_texts = []
for p, c in zip(prompts, completions, strict=True):
if isinstance(p, list):
prompt_text = self.processing_class.apply_chat_template(
p, tokenize=False, add_generation_prompt=True
)
else:
prompt_text = p
if isinstance(c, list):
comp_text = c[0].get("content", "") if c else ""
else:
comp_text = c
gen_texts.append(prompt_text + comp_text)
gen_prompt_texts.append(prompt_text)
gen_encoded = self.processing_class(
text=gen_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=getattr(self.args, "max_length", None)
or getattr(self.args, "max_seq_length", None)
or 2048,
add_special_tokens=False,
)
gen_ids = gen_encoded["input_ids"].to(device)
gen_mask = gen_encoded["attention_mask"].to(device)
# Compute prompt lengths for completion_mean pooling
gen_prompt_lengths = torch.tensor(
[
len(self.processing_class.encode(pt, add_special_tokens=False))
for pt in gen_prompt_texts
],
device=device,
)
# --- Tokenize ground-truth sequences: prompt + ground_truth ---
# For multi-turn (remaining_turns present), render the full GT conversation
# through the chat template to preserve role markers between turns.
gt_texts = []
gt_prompt_texts = []
for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
if i % num_gens != 0:
continue # Only need one GT per prompt group
if isinstance(p, list):
prompt_text = self.processing_class.apply_chat_template(
p, tokenize=False, add_generation_prompt=True
)
# Multi-turn: build full GT conversation with remaining turns
if remaining_turns is not None:
prompt_idx = i // num_gens
turns = (
remaining_turns[prompt_idx]
if prompt_idx < len(remaining_turns)
else []
)
if turns:
gt_conv = list(p) + [{"role": "assistant", "content": gt}]
gt_conv.extend(turns)
full_gt_text = self.processing_class.apply_chat_template(
gt_conv, tokenize=False, add_generation_prompt=False
)
gt_texts.append(full_gt_text)
gt_prompt_texts.append(prompt_text)
continue
else:
prompt_text = p
gt_texts.append(prompt_text + gt)
gt_prompt_texts.append(prompt_text)
gt_encoded = self.processing_class(
text=gt_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=getattr(self.args, "max_length", None)
or getattr(self.args, "max_seq_length", None)
or 2048,
add_special_tokens=False,
)
gt_ids = gt_encoded["input_ids"].to(device)
gt_mask = gt_encoded["attention_mask"].to(device)
gt_prompt_lengths = torch.tensor(
[
len(self.processing_class.encode(pt, add_special_tokens=False))
for pt in gt_prompt_texts
],
device=device,
)
# --- Extract features from frozen feature network ---
# INVARIANT: disable_adapter() yields the unmodified base weights because
# _sync_peft_weights_no_merge and _sync_lora_adapter never call
# merge_adapter() — they compute merged weights as new tensors or save
# the adapter to filesystem. Base weights are never modified in-place.
if self._share_feature_weights:
unwrapped = self.accelerator.unwrap_model(self.model)
feature_ctx = unwrapped.disable_adapter()
else:
unwrapped = self.feature_network
feature_ctx = contextlib.nullcontext()
with feature_ctx:
was_training = unwrapped.training
unwrapped.eval()
gen_hidden = extract_hidden_states(
unwrapped, gen_ids, gen_mask, self.feature_layer_indices
)
gt_hidden = extract_hidden_states(
unwrapped, gt_ids, gt_mask, self.feature_layer_indices
)
if was_training:
unwrapped.train()
# --- Pool to sequence-level embeddings ---
gen_emb = apply_embed_method(
gen_hidden,
args.ebft_embed_method,
gen_mask,
prompt_lengths=gen_prompt_lengths,
)
gt_emb = apply_embed_method(
gt_hidden,
args.ebft_embed_method,
gt_mask,
prompt_lengths=gt_prompt_lengths,
)
# --- Optional whitening ---
batch_size = gen_emb.shape[0]
if args.ebft_use_whitening and batch_size > 1:
num_prompts = batch_size // num_gens
gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
whitened_gen_list = []
whitened_gt_list = []
for i in range(num_prompts):
w_gen, w_gt = whiten_embeddings_batched(
gen_reshaped[i], gt_emb[i : i + 1]
)
whitened_gen_list.append(w_gen)
whitened_gt_list.append(w_gt)
gen_emb = torch.cat(whitened_gen_list, dim=0)
gt_emb = torch.cat(whitened_gt_list, dim=0)
else:
gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)
# Repeat gt_emb: each GT repeated num_generations times
gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)
# --- Compute rewards ---
alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
diversity = get_diversity_rewards(gen_emb, num_gens)
# Scale by 2 per paper equation (7):
# r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
alignment = alignment * 2
diversity = diversity * 2
rewards = (
args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
)
# Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
mean_gen = gen_reshaped.mean(dim=1) # (num_prompts, D)
cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()
# Log feature-matching metrics to console and wandb
_align = alignment.mean().item()
_divers = diversity.mean().item()
_reward = rewards.mean().item()
_cfm = cfm_loss.item()
LOG.info(
f"ebft reward | "
f"align {_align:+.3f} ^ | "
f"divers {_divers:+.3f} v | "
f"cfm {_cfm:.3f} v | "
f"reward {_reward:+.3f} ^"
)
# Log to wandb via trainer's _metrics (picked up by GRPO's logging)
mode = "train" if self.model.training else "eval"
if hasattr(self, "_metrics"):
self._metrics[mode]["ebft/alignment"].append(_align)
self._metrics[mode]["ebft/diversity"].append(_divers)
self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
self._metrics[mode]["ebft/reward"].append(_reward)
return rewards.cpu().tolist()
@torch.no_grad()
def _sequential_rollout(
self,
prompts: list,
first_completions: list,
remaining_turns: list,
num_gens: int,
) -> list:
"""
Extend single-turn completions into multi-turn conversations.
For each prompt group, takes the first generated assistant turn and
sequentially generates subsequent assistant turns by calling vLLM,
building up a full multi-turn conversation.
Args:
prompts: List of prompt message lists (repeated num_gens times)
first_completions: List of generated first-turn completions
remaining_turns: List of remaining turn pairs after first assistant turn.
Each element is a list of dicts: [{"role": "user", "content": "..."},
{"role": "assistant", "content": "...GT..."}]
num_gens: Number of generations per prompt
Returns:
Extended completions incorporating all generated turns
"""
vllm_client = self.vllm_generation.vllm_client
max_tokens = getattr(self.args, "max_completion_length", 256)
temperature = getattr(self.args, "temperature", 0.7)
gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}
extended_completions = []
for idx in range(len(prompts)):
prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
first_comp = first_completions[idx]
# Extract first completion text
if isinstance(first_comp, list):
first_text = first_comp[0].get("content", "") if first_comp else ""
else:
first_text = first_comp
# Get remaining turns for this prompt (same for all num_gens copies)
prompt_idx = idx // num_gens
turns = (
remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
)
if not turns:
extended_completions.append(first_text)
continue
# Build conversation with generated first turn
conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
# Generate subsequent turns
for turn in turns:
if turn["role"] == "user":
conv.append(turn)
elif turn["role"] == "assistant":
try:
result = vllm_client.chat(
messages=[conv],
n=1,
max_tokens=max_tokens,
temperature=temperature,
generation_kwargs=gen_kwargs,
)
gen_ids = result.get("completion_ids", [[]])[0]
gen_text = self.processing_class.decode(
gen_ids, skip_special_tokens=True
)
except Exception as e:
LOG.warning(f"Multi-turn rollout generation failed: {e}")
gen_text = ""
conv.append({"role": "assistant", "content": gen_text})
# Render full conversation through chat template, then extract
# everything after the original prompt as the "completion" text.
# This preserves role markers and formatting between turns.
full_rendered = self.processing_class.apply_chat_template(
conv, tokenize=False, add_generation_prompt=False
)
prompt_rendered = self.processing_class.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
completion_text = full_rendered[len(prompt_rendered) :]
extended_completions.append(completion_text)
return extended_completions
class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
"""EBFT trainer using synchronous GRPO (standard vLLM generation)."""
pass
class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
"""EBFT trainer using async GRPO (prefetches next batch during training)."""
pass
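
The reward combination in `_feature_matching_reward` follows the formula quoted in its comment (paper equation 7). The sketch below works through that formula on plain Python lists for a single prompt group. It is only an illustration: `dot` and `ebft_rewards` are hypothetical names, and the real `get_alignment_rewards` and `get_diversity_rewards` helpers are not shown in this part of the diff, so their internals are assumed to be plain inner products.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def ebft_rewards(gen_emb, gt_emb, alignment_coef=1.0, diversity_coef=1.0):
    """Per-sample reward for one prompt group (hypothetical helper).

    gen_emb: list of n embedding vectors, one per generated sample
    gt_emb:  embedding vector of the ground-truth completion
    """
    n = len(gen_emb)
    rewards = []
    for j, phi_j in enumerate(gen_emb):
        # Alignment term: 2 * phi(y_hat_j)^T phi(y)
        alignment = 2.0 * dot(phi_j, gt_emb)
        # Diversity term: 2/(n-1) * sum over the other samples in the group
        others = [dot(phi_j, gen_emb[k]) for k in range(n) if k != j]
        diversity = 2.0 * sum(others) / (n - 1) if n > 1 else 0.0
        rewards.append(alignment_coef * alignment - diversity_coef * diversity)
    return rewards


# Two orthogonal samples: only the one aligned with the reference is rewarded.
print(ebft_rewards([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0]))
# Two identical samples: the diversity penalty cancels the alignment reward.
print(ebft_rewards([[1.0, 0.0], [1.0, 0.0]], [1.0, 0.0]))
```

The second call shows why the diversity term matters: samples that collapse onto the same embedding earn nothing, even when that embedding matches the reference.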

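
The per-batch token budget computed in `_generate_only` can be reproduced in isolation. This is a minimal sketch of the same arithmetic; `adaptive_max_tokens` is a hypothetical function name, and the default `floor=64` mirrors the constant in the trainer code.

```python
def adaptive_max_tokens(gt_token_counts, multiplier, max_completion, floor=64):
    """Largest scaled ground-truth length in the batch, capped and floored."""
    capped = [min(int(c * multiplier), max_completion) for c in gt_token_counts]
    return max(max(capped), floor)


print(adaptive_max_tokens([10, 100], 1.5, 512))  # longest reference drives the budget
print(adaptive_max_tokens([10], 1.5, 512))       # short references fall back to the floor
print(adaptive_max_tokens([1000], 1.5, 512))     # long references are capped
```

Because the floor is applied after the cap, the result can exceed `max_completion` whenever `max_completion` is configured below 64.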

@@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer):
  """
  def __init__(self, *args, **kwargs):
- # When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
- # The communicator is not needed because weight sync happens via filesystem + HTTP,
- # and it fails when vLLM and a trainer rank share the same CUDA device.
+ # Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
+ # merged weight sync. NCCL is only needed for the standard update_named_param
+ # path which broadcasts tensors through the communicator.
  training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
- if training_args is not None and getattr(
- training_args, "vllm_lora_sync", False
- ):
+ _skip_nccl = False
+ if training_args is not None:
+ if getattr(training_args, "vllm_lora_sync", False):
+ _skip_nccl = True # LoRA sync uses filesystem + HTTP
+ elif getattr(training_args, "async_prefetch", False):
+ # Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
+ # init_communicator allocates device tensors on rank 0 only, which
+ # causes DDP to see different param counts across ranks.
+ # The communicator is initialized lazily on first weight sync instead.
+ _skip_nccl = True
+ if _skip_nccl:
  from trl.generation.vllm_generation import VLLMGeneration
  _orig_init_vllm = VLLMGeneration._init_vllm
@@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
  VLLMGeneration._init_vllm = _init_vllm_no_communicator
- super().__init__(*args, **kwargs)
+ try:
+ super().__init__(*args, **kwargs)
+ finally:
+ # Restore original _init_vllm so other trainers aren't affected
+ if _skip_nccl:
+ VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined]
  # FP8 models: zero out the pad token embedding so that padding
  # positions have zero hidden states throughout the network.
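
The hunk above swaps a class-level method for the duration of `super().__init__` and restores it in a `finally` block. The pattern can be sketched on a toy class. Everything here (`Engine`, `_init`, `_init_no_communicator`, `build`) is a hypothetical stand-in, not the TRL or axolotl API.

```python
class Engine:
    """Stand-in for a class whose initializer we want to patch temporarily."""

    def _init(self):
        return "with communicator"


def _init_no_communicator(self):
    """Replacement initializer used only while the patch is active."""
    return "without communicator"


def build(skip: bool) -> str:
    original = Engine._init
    if skip:
        # Patch at class level, so every instance created meanwhile sees it.
        Engine._init = _init_no_communicator
    try:
        return Engine()._init()
    finally:
        # Always restore, even if construction raises, so later users of
        # Engine are not affected by the temporary patch.
        if skip:
            Engine._init = original


print(build(True))        # patched behaviour during the call
print(Engine()._init())   # original behaviour afterwards
```

Restoring in `finally` matters because the patch is global to the class: without it, a failed constructor would leave every later trainer with the patched method.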
@@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
  self._executor = None
  def _submit_generation(self):
- """Submit the next background generation job."""
+ """Submit the next background generation job.
+ With multi-process (DDP/FSDP), only rank 0 generates to avoid
+ cross-rank NCCL collectives from background threads. Non-rank-0
+ processes enqueue a sentinel ``None`` that is replaced by a
+ broadcast in ``_prepare_inputs_legacy_async``.
+ """
+ rank0_only = self.accelerator.num_processes > 1
+ if rank0_only and not self.accelerator.is_main_process:
+ # Non-rank-0: nothing to generate; enqueue a resolved None future
+ f: concurrent.futures.Future = concurrent.futures.Future()
+ f.set_result(None)
+ self._async_queue.put(f)
+ return
  batch = next(self._prompt_iter)
- future = self._executor.submit(self._generate_only, batch)
+ future = self._executor.submit(self._generate_only, batch, rank0_only)
  self._async_queue.put(future)
+ # ------------------------------------------------------------------
+ # Broadcast rollout (legacy async, multi-process)
+ # ------------------------------------------------------------------
+ def _broadcast_rollout(self, rollout: dict | None) -> dict:
+ """Broadcast a rank0-only rollout dict to all ranks (main thread).
+ Rank 0 has the full rollout dict from ``_generate_only``; other ranks
+ have ``None``. After broadcast, tensors are moved to each rank's
+ local device.
+ """
+ import torch.distributed as dist
+ obj_list = [rollout if self.accelerator.is_main_process else None]
+ dist.broadcast_object_list(obj_list, src=0)
+ rollout = obj_list[0]
+ assert rollout is not None, "broadcast_object_list failed to deliver rollout"
+ # Move tensors to local device (broadcast deserializes to CPU)
+ device = self.accelerator.device
+ for key, val in rollout.items():
+ if isinstance(val, torch.Tensor) and val.device != device:
+ rollout[key] = val.to(device)
+ return rollout
  # ------------------------------------------------------------------
  # Weight sync
  # ------------------------------------------------------------------
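
The sentinel trick in `_submit_generation` keeps the queue protocol identical on every rank: consumers always pop a future and call `.result()`, and non-generating ranks simply get `None` back. A minimal standalone sketch follows. `submit` and `generate_stub` are hypothetical names, and the rank information is passed in as plain arguments instead of coming from an accelerator object.

```python
import concurrent.futures
import queue


def generate_stub(batch, rank0_only):
    """Stand-in for the real generation call."""
    return {"batch": batch, "rank0_only": rank0_only}


def submit(async_queue, executor, is_main_process, num_processes, batch):
    rank0_only = num_processes > 1
    if rank0_only and not is_main_process:
        # Nothing to generate on this rank: enqueue an already-resolved future
        # so the consumer can treat all ranks uniformly.
        f = concurrent.futures.Future()
        f.set_result(None)
        async_queue.put(f)
        return
    async_queue.put(executor.submit(generate_stub, batch, rank0_only))


q = queue.Queue()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    submit(q, pool, is_main_process=False, num_processes=2, batch=[1])
    print(q.get().result())  # sentinel on a non-main rank
    submit(q, pool, is_main_process=True, num_processes=2, batch=[1])
    print(q.get().result())  # real rollout on the main rank
```

In the trainer, the `None` result is later replaced by the broadcast from rank 0, which runs on the main thread instead of a background thread.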
@@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer):
  for Float8), and also safe for concurrent use since it never modifies base
  weights in-place.
  """
- model = self.vllm_generation.model
  accelerator = self.vllm_generation.accelerator
- vllm_client = self.vllm_generation.vllm_client
- fix_name = self.vllm_generation._fix_param_name_to_vllm
  if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
  return
+ # In multi-GPU async mode, we skip NCCL communicator init to avoid
+ # DDP param count mismatch and NCCL device conflicts. Weight sync
+ # uses the HTTP-only fallback in batch_update_named_params instead.
+ model = self.vllm_generation.model
+ vllm_client = self.vllm_generation.vllm_client
+ fix_name = self.vllm_generation._fix_param_name_to_vllm
  # Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
  lora_info = {}
  for mod_name, module in model.base_model.model.named_modules():
@@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
  weight_name = pname.replace(".weight_scale_inv", ".weight")
  scale_inv_lookup[weight_name] = pparam.data
- # Iterate all parameters, computing merged weights for LoRA layers.
- # Skip LoRA-specific params and FP8 scale params (scales will be
- # recomputed by vLLM when it receives the merged bf16 weight).
+ # Only sync parameters that have LoRA modifications — skip unchanged
+ # base weights to avoid OOM on the vLLM GPU from allocating the entire
+ # model's worth of NCCL receive buffers.
  params_to_sync = []
+ compute_dtype = torch.bfloat16
  for name, param in model.named_parameters():
  vllm_name = name.removeprefix("base_model.model.").replace(
  ".base_layer", ""
@@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer):
  continue
  if "original_module" in vllm_name:
  continue
- # Skip FP8 quantization scale parameters - they are recomputed
- # on the vLLM side when we update the weight itself
  if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
  continue
+ if not vllm_name.endswith(".weight"):
+ continue
+ # fix_name strips modules_to_save.default. prefix
+ raw_mod_path = vllm_name[: -len(".weight")]
  vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
+ mod_path = vllm_name[: -len(".weight")]
+ # Sync weights that have LoRA adapters OR are modules_to_save
+ is_lora = mod_path in lora_info
+ is_modules_to_save = raw_mod_path != mod_path # fix_name stripped a prefix
+ if not is_lora and not is_modules_to_save:
+ continue
  data = param.data
- compute_dtype = torch.bfloat16
- if vllm_name.endswith(".weight"):
- # Dequantize FP8 weights before merging
- if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
- scale_inv = scale_inv_lookup[name]
- # Block dequantization: weight * scale_inv (with broadcasting)
- fp8_bf16 = data.to(compute_dtype)
- if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
- # Block-quantized: scale_inv shape (rows/block, cols/block)
- sr, sc = scale_inv.shape
- br = fp8_bf16.shape[0] // sr # block height
- bc = fp8_bf16.shape[1] // sc # block width
- # Reshape → multiply by block scale → reshape back
- data = (
- fp8_bf16.reshape(sr, br, sc, bc)
- * scale_inv[:, None, :, None].to(compute_dtype)
- ).reshape(fp8_bf16.shape)
- elif scale_inv.dim() <= 1:
- # Per-tensor or per-channel scale
- data = fp8_bf16 * scale_inv.to(compute_dtype)
- else:
- data = fp8_bf16
- elif data.dtype == torch.float8_e4m3fn:
- # FP8 but no scale found - just cast (lossy)
- data = data.to(compute_dtype)
+ # Dequantize FP8 weights before merging
+ if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
+ scale_inv = scale_inv_lookup[name]
+ fp8_bf16 = data.to(compute_dtype)
+ if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
+ sr, sc = scale_inv.shape
+ br = fp8_bf16.shape[0] // sr
+ bc = fp8_bf16.shape[1] // sc
+ data = (
+ fp8_bf16.reshape(sr, br, sc, bc)
+ * scale_inv[:, None, :, None].to(compute_dtype)
+ ).reshape(fp8_bf16.shape)
+ elif scale_inv.dim() <= 1:
+ data = fp8_bf16 * scale_inv.to(compute_dtype)
+ else:
+ data = fp8_bf16
+ elif data.dtype == torch.float8_e4m3fn:
+ data = data.to(compute_dtype)
- mod_path = vllm_name[: -len(".weight")]
- if mod_path in lora_info:
- A, B, s = lora_info[mod_path]
- merged = data.to(compute_dtype) + s * (
- B.to(compute_dtype) @ A.to(compute_dtype)
- )
- data = merged
- params_to_sync.append((vllm_name, data))
- # Batch sync all params in one HTTP+NCCL call (vs individual calls)
+ if is_lora:
+ A, B, s = lora_info[mod_path]
+ merged = data.to(compute_dtype) + s * (
+ B.to(compute_dtype) @ A.to(compute_dtype)
+ )
+ params_to_sync.append((vllm_name, merged))
+ else:
+ # modules_to_save: send raw weight (no LoRA merge needed)
+ params_to_sync.append((vllm_name, data.to(compute_dtype)))
+ # Batch sync only LoRA-modified params via HTTP+NCCL
  if params_to_sync:
+ sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
+ logger.info(
+ f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
+ )
  vllm_client.batch_update_named_params(params_to_sync)
  # Reset prefix cache after weight update
@@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
  vllm_client = self.vllm_generation.vllm_client
  url = f"{vllm_client.base_url}/set_lora_adapter/"
+ sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
  response = requests.post(
  url,
  json={
@@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
  "lora_int_id": self._lora_sync_version,
  "lora_path": adapter_path,
  },
- timeout=30,
+ timeout=sync_timeout,
  )
  if response.status_code != 200:
  logger.warning(
@@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
  step = self.state.global_step
  interval = self.args.vllm_sync_interval
  if step != self._last_synced_step and step % interval == 0:
+ if step == 0:
+ logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
+ self._last_synced_step = step
+ return
  if getattr(self.args, "vllm_lora_sync", False):
- if step == 0:
- logger.info("Skipping LoRA sync at step 0 (no training yet)")
- self._last_synced_step = step
- return
  # Native LoRA sync: save adapter to filesystem, vLLM loads it directly
  self._sync_lora_adapter()
  else:
@@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
  # Background-thread generation (no scoring)
  # ------------------------------------------------------------------
- def _generate_single_turn(self, prompts, **kwargs):
+ def _generate_single_turn(self, prompts, *args, **kwargs):
  """Override to prevent weight sync from background thread and to use
  no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
  is_bg = threading.current_thread() is not threading.main_thread()
@@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
  self._patched_sync_weights = True
  try:
- return super()._generate_single_turn(prompts, **kwargs)
+ return super()._generate_single_turn(prompts, *args, **kwargs)
  finally:
  if saved_step is not None:
  self._last_loaded_step = saved_step
@@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
  output = vg.vllm_client.chat(
  messages=unique_prompts,
  **sampling_params,
- chat_template_kwargs=vg.chat_template_kwargs,
- tools=vg.tools,
- chat_template=vg.chat_template,
+ chat_template_kwargs=self.chat_template_kwargs,
+ tools=self.tools,
+ chat_template=getattr(self, "chat_template", None),
  )
  else:
  output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
@@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
  logps_diff = per_token_logps_diff
  is_ratio = torch.exp(logps_diff)
+ is_floor = 1.0 / is_cap # symmetric floor (e.g., cap=3.0 -> floor=0.333)
  if is_mode in ("sequence_truncate", "token_truncate"):
- is_ratio = torch.clamp(is_ratio, max=is_cap)
+ is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
  elif is_mode in ("sequence_mask", "token_mask"):
  is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
+ is_ratio = is_ratio.clamp(min=is_floor)
  data["importance_sampling_ratio"] = is_ratio
  # --- Collect rewards (launched before logprobs, should be done) ---
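
The symmetric floor added in this hunk can be traced on a few scalar values. The sketch below mirrors the two branches with plain Python floats in place of tensors; `clamp_is_ratio` is a hypothetical helper written for illustration, not a function from the trainer.

```python
import math


def clamp_is_ratio(logps_diff, is_cap, is_mode):
    """Apply the hunk's truncate/mask logic to a list of log-prob differences."""
    is_floor = 1.0 / is_cap  # symmetric floor, e.g. cap=3.0 -> floor=1/3
    ratios = [math.exp(d) for d in logps_diff]
    if is_mode in ("sequence_truncate", "token_truncate"):
        # Equivalent of torch.clamp(is_ratio, min=is_floor, max=is_cap)
        return [min(max(r, is_floor), is_cap) for r in ratios]
    if is_mode in ("sequence_mask", "token_mask"):
        # Equivalent of masked_fill(is_ratio > is_cap, 0.0) ...
        masked = [0.0 if r > is_cap else r for r in ratios]
        # ... followed by clamp(min=is_floor), in that order, as in the hunk
        return [max(r, is_floor) for r in masked]
    return ratios


diffs = [math.log(0.1), 0.0, math.log(5.0)]  # ratios of about 0.1, 1.0, 5.0
print(clamp_is_ratio(diffs, 3.0, "token_truncate"))
print(clamp_is_ratio(diffs, 3.0, "token_mask"))
```

With the statements in the order shown in the hunk, a ratio above the cap is first set to zero and then lifted to the floor by the following clamp, so in mask mode it ends up at `1 / is_cap`, not at zero.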
@@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
  seq_is = is_mode in ("sequence_mask", "sequence_truncate")
  logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
  is_ratio = torch.exp(logps_diff)
+ # Symmetric floor clamp (matches non-streaming path at line ~1651)
+ is_floor = 1.0 / is_cap
  if is_mode in ("sequence_truncate", "token_truncate"):
- is_ratio = torch.clamp(is_ratio, max=is_cap)
+ is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
  elif is_mode in ("sequence_mask", "token_mask"):
  is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
+ is_ratio = is_ratio.clamp(min=is_floor)
  if "importance_sampling_ratio" not in data:
  total = len(data["prompt_ids"])
  shape = (total, 1) if seq_is else (total, is_ratio.size(1))
@@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
rollout = future.result() rollout = future.result()
self._submit_generation() self._submit_generation()
# With multi-process, only rank 0 generated. Broadcast to all ranks.
if self.accelerator.num_processes > 1:
rollout = self._broadcast_rollout(rollout)
if self.args.streaming_partial_batch: if self.args.streaming_partial_batch:
micro_batches = self._score_streaming(rollout) micro_batches = self._score_streaming(rollout)
else: else:
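
The symmetric clamp added in the hunks above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the trainer's code: `clamp_is_ratio` is a hypothetical helper, lists stand in for tensors, and `math.exp` stands in for `torch.exp`. It follows the order of operations shown in the diff, so in the mask modes a ratio above `is_cap` is first zeroed and then raised to `is_floor` by the final clamp.

```python
import math


def clamp_is_ratio(log_diffs, is_cap, mode):
    """Hypothetical stand-in for the importance-sampling clamp in the diff."""
    is_floor = 1.0 / is_cap  # symmetric floor, as in the diff
    ratios = [math.exp(d) for d in log_diffs]
    if mode in ("sequence_truncate", "token_truncate"):
        # Clamp into [is_floor, is_cap]
        return [min(max(r, is_floor), is_cap) for r in ratios]
    if mode in ("sequence_mask", "token_mask"):
        # Zero out ratios above the cap, then apply the floor
        masked = [0.0 if r > is_cap else r for r in ratios]
        return [max(r, is_floor) for r in masked]
    return ratios


truncated = clamp_is_ratio([0.0, math.log(10.0)], 2.0, "token_truncate")
masked = clamp_is_ratio([0.0, math.log(10.0)], 2.0, "token_mask")
```

Whether masked entries are meant to end at `is_floor` rather than 0 is not stated in the diff; the sketch only mirrors the ordering as written.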


@@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback):
logger.info("=" * 60) logger.info("=" * 60)
if self.trainer.axolotl_cfg.use_wandb: if self.trainer.axolotl_cfg.use_wandb:
if wandb.run is not None: if wandb.run is not None: # type: ignore[attr-defined]
wandb.log( wandb.log( # type: ignore[attr-defined]
{ {
"generated_samples": wandb.Table( "generated_samples": wandb.Table( # type: ignore[attr-defined]
columns=[ columns=[
"step", "step",
"original", "original",


@@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__)
def _batch_update_named_params( def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
): ):
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL.""" """Batched weight sync — uses NCCL if communicator available, HTTP otherwise."""
from transformers import is_torch_xpu_available has_communicator = getattr(self, "communicator", None) is not None
if chunk_size is None: if has_communicator:
chunks = [params] # Fast path: metadata via HTTP, tensors via NCCL
else: from transformers import is_torch_xpu_available
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
for chunk in chunks: if chunk_size is None:
param_metadata = [ chunks = [params]
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
else: else:
self.communicator.group.barrier() chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
for chunk in chunks:
param_metadata = [
{
"name": name,
"dtype": str(weights.dtype),
"shape": list(weights.shape),
}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(
url, json={"params": param_metadata}, timeout=120
)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
else:
# HTTP-only path: encode tensor data in request body (no NCCL needed).
# Batch by byte size to avoid huge HTTP payloads.
MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024 # 10 MB
HTTP_TIMEOUT = 120 # seconds per request
payload: list[dict] = []
payload_bytes = 0
url = f"{self.base_url}/http_update_weights/"
def _flush(p: list[dict]) -> None:
if not p:
return
response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)
from axolotl.utils.weight_serde import encode_for_http
for name, weights in params:
entry = encode_for_http(name, weights)
entry_bytes = weights.nelement() * weights.element_size()
# Flush current batch if adding this entry would exceed limit
if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST:
_flush(payload)
payload = []
payload_bytes = 0
payload.append(entry)
payload_bytes += entry_bytes
_flush(payload) # send remaining
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None): def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
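
The byte-size batching used by the HTTP-only path can be sketched on its own. This is a minimal illustration, assuming each entry is a `(name, nbytes)` pair; `batch_by_bytes` is a hypothetical name, and the real code sends each batch through `_flush` instead of yielding it.

```python
def batch_by_bytes(entries, max_bytes):
    """Group (name, nbytes) pairs so each batch stays under max_bytes.

    A single oversized entry still gets a batch of its own, matching the
    `if payload and ...` guard in the diff.
    """
    batch, batch_bytes = [], 0
    for name, nbytes in entries:
        # Flush the current batch if adding this entry would exceed the limit
        if batch and batch_bytes + nbytes > max_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append(name)
        batch_bytes += nbytes
    if batch:
        yield batch  # remaining entries


batches = list(batch_by_bytes([("a", 6), ("b", 6), ("c", 3), ("d", 20)], 10))
```

Here `"d"` is larger than the limit but is still sent, alone, in the last batch.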


@@ -0,0 +1,9 @@
"""
module for EBFT style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.ebft")


@@ -0,0 +1,129 @@
"""
Dataset transform for multi-turn chat data with structured EBFT (vLLM mode).
Three variants:
1. `transform` — Uses the FIRST assistant turn as the generation target.
Passes remaining turns as `remaining_turns` for sequential rollout.
The trainer generates turn 1 via GRPO/vLLM, then sequentially generates
subsequent assistant turns, comparing the full conversation to GT.
2. `transform_last_turn` — Uses the LAST assistant turn as the target.
Simplest approach: the full conversation history is the prompt.
3. `transform_all_turns` — Explodes each conversation into N examples
(one per assistant turn). Each turn is an independent training example.
Use with batched=True.
Supports OpenAI chat format:
{"messages": [{"role": ..., "content": ...}, ...]}
"""
def transform(cfg, **kwargs):
"""Multi-turn with sequential rollout.
Returns the first assistant turn as ground_truth, plus remaining_turns
for the trainer to do sequential rollout generation.
"""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if not messages:
return {"prompt": [], "ground_truth": ""}
# Split at first assistant turn
prompt_msgs = []
first_gt = None
remaining = []
found_first = False
for msg in messages:
if msg["role"] == "assistant" and not found_first:
first_gt = msg["content"]
found_first = True
elif found_first:
remaining.append(msg)
else:
prompt_msgs.append(msg)
if first_gt is None:
return {"prompt": prompt_msgs, "ground_truth": ""}
# Store only the first assistant turn as ground_truth. The full multi-turn
# GT is reconstructed in the reward function via chat template rendering
# (using remaining_turns), which preserves role markers between turns.
return {
"prompt": prompt_msgs,
"ground_truth": first_gt,
"remaining_turns": remaining,
}
return transform_fn, {
"remove_columns": "__all__",
}
def transform_last_turn(cfg, **kwargs):
"""Single-turn: use the last assistant turn as the generation target."""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if not messages:
return {"prompt": [], "ground_truth": ""}
# Track the most recent assistant turn; the history before it becomes the prompt
history = []
last_prompt = []
last_gt = ""
for msg in messages:
if msg["role"] == "assistant":
last_prompt = list(history)
last_gt = msg["content"]
history.append(msg)
return {
"prompt": last_prompt,
"ground_truth": last_gt,
}
return transform_fn, {
"remove_columns": "__all__",
}
def transform_all_turns(cfg, **kwargs):
"""Explode: one example per assistant turn.
Use with datasets.map(batched=True) to produce N examples from
each N-turn conversation.
Usage in YAML:
type: ebft_chat_multiturn.transform_all_turns
"""
def transform_fn(examples, tokenizer=None):
all_prompts = []
all_ground_truths = []
messages_list = examples.get("messages", examples.get("conversations", []))
for messages in messages_list:
history = []
for msg in messages:
if msg["role"] == "assistant":
all_prompts.append(list(history))
all_ground_truths.append(msg["content"])
history.append(msg)
return {
"prompt": all_prompts,
"ground_truth": all_ground_truths,
}
return transform_fn, {
"remove_columns": "__all__",
"batched": True,
}
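
The "explode" behavior of `transform_all_turns` can be checked on a small batch. The sketch below copies the inner loop from the diff into a standalone function; `explode_turns` is a hypothetical name and the sample conversation is made up for illustration.

```python
def explode_turns(batch):
    """Produce one (prompt, ground_truth) pair per assistant turn."""
    prompts, ground_truths = [], []
    for messages in batch["messages"]:
        history = []
        for msg in messages:
            if msg["role"] == "assistant":
                # Everything seen so far is the prompt for this turn
                prompts.append(list(history))
                ground_truths.append(msg["content"])
            history.append(msg)
    return {"prompt": prompts, "ground_truth": ground_truths}


sample = {
    "messages": [
        [
            {"role": "user", "content": "u1"},
            {"role": "assistant", "content": "a1"},
            {"role": "user", "content": "u2"},
            {"role": "assistant", "content": "a2"},
        ]
    ]
}
exploded = explode_turns(sample)
```

One conversation with two assistant turns yields two rows, which is why the transform has to run with `batched=True`: the output row count differs from the input row count.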


@@ -0,0 +1,20 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode.
Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer (prompt + ground_truth).
"""
def transform(cfg, **kwargs):
def transform_fn(example, tokenizer=None):
return {
"prompt": [
{"role": "user", "content": example["input"]},
],
"ground_truth": example["output"],
}
return transform_fn, {
"remove_columns": "__all__",
}


@@ -0,0 +1,319 @@
"""
Dataset transform for reasoning/thinking datasets with EBFT.
Handles datasets where assistant responses contain <think>...</think> reasoning
traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs).
Three variants:
1. `transform` — For structured EBFT (vLLM mode):
Returns prompt + ground_truth with thinking tags preserved.
Feature matching compares full responses (thinking + answer).
2. `transform_answer_only` — For structured EBFT (vLLM mode):
Strips <think>...</think> from ground_truth, so feature matching
only scores the final answer portion. Use when reasoning chains
can vary but the answer should match.
3. `transform_strided` — For strided EBFT:
Tokenizes the full conversation with thinking traces.
Optionally masks thinking tokens from CE loss (labels=-100 for think spans)
while still placing anchors in thinking regions for feature matching.
All variants work with OpenAI chat format:
{"messages": [{"role": "...", "content": "<think>...</think>Answer"}]}
"""
import re
def _strip_thinking(text: str) -> str:
"""Remove <think>...</think> blocks from text."""
return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
def _extract_thinking(text: str) -> tuple[str, str]:
"""Split text into (thinking, answer) parts."""
match = re.search(r"<think>(.*?)</think>\s*(.*)", text, flags=re.DOTALL)
if match:
return match.group(1).strip(), match.group(2).strip()
return "", text.strip()
def transform(cfg, **kwargs):
"""Full response including thinking traces for feature matching.
For datasets where assistant content has <think>...</think> tags in the
content field. The ground_truth includes the full content (thinking + answer).
"""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
prompt_msgs_snapshot = None
ground_truth = ""
for msg_idx, msg in enumerate(messages):
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(messages[:msg_idx])
ground_truth = msg["content"]
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_split_thinking(cfg, **kwargs):
"""Split <think> tags into reasoning_content field for native chat template handling.
For datasets where thinking is embedded in the content field as <think>...</think>.
Splits it into separate reasoning_content and content fields so the model's
chat template can format it natively (e.g., Qwen3.5's reasoning_content support).
The prompt messages are passed through with reasoning_content properly split,
so vLLM generation with enable_thinking=true produces comparable outputs.
The ground_truth is the full assistant response (thinking + answer) for
feature matching.
Also works for:
- <reasoning>...</reasoning> tags
- <|begin_of_thought|>...<|end_of_thought|> tags
"""
_THINKING_PAIRS = [
("<think>", "</think>"),
("<reasoning>", "</reasoning>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
]
def _split_msg_thinking(msg):
"""Split thinking from assistant message content into reasoning_content.
Always includes reasoning_content key on assistant messages (empty string
if no thinking tags found) to ensure consistent HF dataset schema across
all examples in a batch.
"""
if msg["role"] != "assistant":
return msg
content = msg.get("content", "")
# Already has reasoning_content — pass through
if "reasoning_content" in msg:
return msg
for open_tag, close_tag in _THINKING_PAIRS:
if open_tag in content and close_tag in content:
start = content.find(open_tag)
# Look for the closing tag only after the opening tag
end = content.find(close_tag, start + len(open_tag))
if end == -1:
continue
thinking = content[start + len(open_tag) : end].strip()
answer = content[end + len(close_tag) :].strip()
return {
**msg,
"reasoning_content": thinking,
"content": answer,
}
# No thinking tags — still add reasoning_content for schema consistency
return {**msg, "reasoning_content": ""}
def _normalize_msg(msg):
"""Ensure every message has {role, content, reasoning_content} for HF schema consistency."""
return {
"role": msg.get("role", ""),
"content": msg.get("content", ""),
"reasoning_content": msg.get("reasoning_content", ""),
}
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
# Split thinking in all assistant messages, then normalize schema
split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages]
# Build prompt (all messages before the last assistant turn) and ground_truth
prompt_msgs = []
prompt_msgs_snapshot = None
ground_truth = ""
for msg in split_messages:
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(prompt_msgs)
# ground_truth is the FULL content for feature matching
thinking = msg.get("reasoning_content", "")
answer = msg.get("content", "")
if thinking:
ground_truth = f"<think>\n{thinking}\n</think>\n\n{answer}"
else:
ground_truth = answer
prompt_msgs.append(msg)
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else split_messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_answer_only(cfg, **kwargs):
"""Strip thinking from ground_truth — match features on answer only."""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
prompt_msgs = []
prompt_msgs_snapshot = None
ground_truth = ""
for msg in messages:
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(prompt_msgs)
ground_truth = _strip_thinking(msg["content"])
prompt_msgs.append(msg)
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_strided(cfg, **kwargs):
"""For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss.
Config options (via cfg):
- ebft.mask_thinking_ce: bool (default False)
If True, set labels=-100 for tokens inside <think>...</think> blocks.
Feature matching still uses these positions (anchors are placed everywhere
in the completion span). Only CE auxiliary loss is affected.
"""
seq_len = cfg.sequence_len
mask_thinking = False
if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"):
mask_thinking = cfg.ebft.mask_thinking_ce
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if tokenizer is None:
for m in messages:
if m.get("role") == "user":
return {"prompt": m["content"]}
return {"prompt": str(messages)}
pad_id = (
tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id
)
# Tokenize the full conversation with the chat template
full_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
full_enc = tokenizer(
full_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)
input_ids = full_enc["input_ids"]
# Build labels: -100 for non-assistant tokens
labels = [-100] * len(input_ids)
# Find assistant turn boundaries using incremental tokenization.
# Only the FINAL assistant turn is marked as trainable.
prefix_messages = []
final_start = None
final_end = None
for msg in messages:
if msg["role"] == "assistant":
prefix_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=True,
)
prefix_ids = tokenizer(
prefix_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
start = len(prefix_ids)
prefix_messages.append(msg)
with_turn_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=False,
)
with_turn_ids = tokenizer(
with_turn_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
end = len(with_turn_ids)
# Record this turn's boundaries; only the last one will be used
final_start = start
final_end = end
else:
prefix_messages.append(msg)
# Mark only the final assistant turn as trainable
if final_start is not None and final_end is not None:
for i in range(final_start, min(final_end, len(labels))):
labels[i] = input_ids[i]
# Optionally mask <think>...</think> tokens within this turn.
# Find think spans by scanning for <think> and </think> token IDs
# directly in the input_ids (robust to tokenization alignment).
if mask_thinking:
think_open_id = tokenizer.convert_tokens_to_ids("<think>")
think_close_id = tokenizer.convert_tokens_to_ids("</think>")
if think_open_id != tokenizer.unk_token_id:
# Scan from before the assistant turn start to catch
# <think> tags that are part of the template prefix
scan_start = max(0, final_start - 5)
in_think = False
for i in range(scan_start, min(final_end, len(labels))):
if input_ids[i] == think_open_id:
in_think = True
if in_think and i >= final_start:
labels[i] = -100
if input_ids[i] == think_close_id:
in_think = False
if i >= final_start:
labels[i] = -100
# Derive prompt_length
prompt_length = len(input_ids)
for i, lbl in enumerate(labels):
if lbl != -100:
prompt_length = i
break
# Pad
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
return transform_fn, {"remove_columns": "__all__"}
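
The think-span masking at the end of `transform_strided` can be sketched with plain integer IDs. This is an illustration only: `mask_think_span` is a hypothetical helper, the token IDs 90 and 91 are made-up stand-ins for `<think>` and `</think>`, and the `lookback` of 5 mirrors the `final_start - 5` scan in the diff.

```python
def mask_think_span(input_ids, labels, start, end, open_id, close_id, lookback=5):
    """Set labels to -100 inside a think span within [start, end)."""
    labels = list(labels)
    in_think = False
    # Start a few tokens early so an opening tag in the template prefix is seen
    for i in range(max(0, start - lookback), min(end, len(labels))):
        if input_ids[i] == open_id:
            in_think = True
        if in_think and i >= start:
            labels[i] = -100
        if input_ids[i] == close_id:
            in_think = False
            if i >= start:
                labels[i] = -100
    return labels


# The opening tag (90) sits at index 2, just before the assistant turn at 3.
ids = [1, 2, 90, 5, 6, 91, 7, 8]
lbl = [-100, -100, -100, 5, 6, 91, 7, 8]
masked_labels = mask_think_span(ids, lbl, start=3, end=8, open_id=90, close_id=91)
```

Only the answer tokens after the closing tag keep their labels, so the CE loss skips the reasoning trace while the positions themselves stay in `input_ids`.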


@@ -0,0 +1,110 @@
"""
Dataset transform for multi-turn chat data with strided EBFT.
Tokenizes conversations using the model's chat template, producing input_ids
with labels=-100 for system/user turns and real labels for assistant turns.
The strided trainer places anchors only within assistant completion spans.
Works with datasets in OpenAI chat format:
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
"""
def transform(cfg, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if tokenizer is None:
# For preview: just return the first user message
for m in messages:
if m.get("role") == "user":
return {"prompt": m["content"]}
return {"prompt": str(messages)}
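# Use an explicit None check so that a valid pad_token_id of 0 is kept
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id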
# Tokenize the full conversation with the chat template
full_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
full_enc = tokenizer(
full_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)
input_ids = full_enc["input_ids"]
# Build labels: -100 for everything except assistant turns.
# Strategy: tokenize incrementally to find assistant turn boundaries.
labels = [-100] * len(input_ids)
# Tokenize prefix up to each assistant turn to find boundaries
prefix_messages = []
for msg in messages:
if msg["role"] == "assistant":
# Tokenize prefix (everything before this assistant turn + generation prompt)
prefix_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=True,
)
prefix_ids = tokenizer(
prefix_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
start = len(prefix_ids)
# Tokenize prefix + this assistant turn
prefix_messages.append(msg)
with_turn_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=False,
)
with_turn_ids = tokenizer(
with_turn_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
end = len(with_turn_ids)
# Mark assistant tokens as trainable
for i in range(start, min(end, len(labels))):
labels[i] = input_ids[i]
else:
prefix_messages.append(msg)
# Derive prompt_length as the position of the first non-masked label
prompt_length = len(input_ids) # default: all masked
for i, lbl in enumerate(labels):
if lbl != -100:
prompt_length = i
break
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
return transform_fn, {
"remove_columns": "__all__",
}
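
The incremental-tokenization trick for finding assistant boundaries can be demonstrated without a real tokenizer. The sketch below assumes a toy chat template (`render`) and whitespace tokenization (`tokenize`), both hypothetical stand-ins for `tokenizer.apply_chat_template` and the tokenizer call; string tokens play the role of token IDs.

```python
def render(messages, add_generation_prompt):
    """Toy chat template: '<role> content <end>' per message."""
    parts = [f"<{m['role']}> {m['content']} <end>" for m in messages]
    if add_generation_prompt:
        parts.append("<assistant>")
    return " ".join(parts)


def tokenize(text):
    return text.split()


def assistant_labels(messages):
    input_ids = tokenize(render(messages, False))
    labels = [-100] * len(input_ids)
    prefix = []
    for msg in messages:
        if msg["role"] == "assistant":
            # Length of prefix + generation prompt marks where the turn starts
            start = len(tokenize(render(prefix, True)))
            prefix.append(msg)
            # Length of prefix + this turn marks where it ends
            end = len(tokenize(render(prefix, False)))
            for i in range(start, min(end, len(labels))):
                labels[i] = input_ids[i]
        else:
            prefix.append(msg)
    return input_ids, labels


toy_ids, toy_labels = assistant_labels(
    [
        {"role": "user", "content": "hi there"},
        {"role": "assistant", "content": "hello"},
    ]
)
```

The approach relies on the rendered prefix being a token-level prefix of the full rendering, which holds for this toy template but is an assumption for real chat templates.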


@@ -0,0 +1,80 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.
Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).
Works with datasets that have prompt/completion-style fields (`input`/`output`,
`prompt`/`completion`, or `question`/`answer`), e.g., nvidia/OpenCodeInstruct.
"""
def transform(cfg, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
# Extract prompt and completion from the example
prompt_text = example.get(
"input", example.get("prompt", example.get("question", ""))
)
completion_text = example.get(
"output", example.get("completion", example.get("answer", ""))
)
if tokenizer is None:
return {"prompt": prompt_text}
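# Use an explicit None check so that a valid pad_token_id of 0 is kept
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id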
# Tokenize prompt and completion separately
prompt_enc = tokenizer(
prompt_text,
truncation=False,
add_special_tokens=True,
return_tensors=None,
)
completion_enc = tokenizer(
completion_text,
truncation=False,
add_special_tokens=False,
return_tensors=None,
)
prompt_ids = prompt_enc["input_ids"]
completion_ids = completion_enc["input_ids"]
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
total_len = len(prompt_ids) + len(completion_ids)
if total_len > seq_len:
# Truncate completion first, then prompt if needed
max_completion = seq_len - len(prompt_ids)
if max_completion < 1:
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
prompt_ids = prompt_ids[: seq_len - 1]
completion_ids = completion_ids[:1]
else:
completion_ids = completion_ids[:max_completion]
input_ids = prompt_ids + completion_ids
prompt_length = len(prompt_ids)
# Labels: -100 for prompt tokens, input_ids for completion tokens
labels = [-100] * prompt_length + completion_ids
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
# Signal to remove all original columns (filtered to existing ones at map time)
return transform_fn, {
"remove_columns": "__all__",
}
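
The truncate-then-pad logic of this transform can be exercised with plain lists. The sketch below lifts the steps from the diff into a hypothetical `build_example` function that takes already-tokenized IDs, so no tokenizer is needed.

```python
def build_example(prompt_ids, completion_ids, seq_len, pad_id):
    """Truncate to seq_len, mask the prompt in labels, then right-pad."""
    if len(prompt_ids) + len(completion_ids) > seq_len:
        max_completion = seq_len - len(prompt_ids)
        if max_completion < 1:
            # Prompt alone fills seq_len: keep room for one completion token
            prompt_ids = prompt_ids[: seq_len - 1]
            completion_ids = completion_ids[:1]
        else:
            completion_ids = completion_ids[:max_completion]
    input_ids = prompt_ids + completion_ids
    prompt_length = len(prompt_ids)
    labels = [-100] * prompt_length + completion_ids
    pad_len = seq_len - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * pad_len,
        "attention_mask": [1] * len(input_ids) + [0] * pad_len,
        "labels": labels + [-100] * pad_len,
        "prompt_length": prompt_length,
    }


short = build_example([1], [2], seq_len=4, pad_id=0)
long_prompt = build_example([1, 2, 3, 4, 5, 6], [7, 8], seq_len=4, pad_id=0)
```

In the long-prompt case the prompt is cut to `seq_len - 1` tokens so that one completion token survives for anchor placement.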


@@ -241,6 +241,23 @@ def main(script_args: ScriptArguments):
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
# --- Access logging middleware ---
import time as _time
@app.middleware("http")
async def access_log_middleware(request, call_next):
t0 = _time.monotonic()
response = await call_next(request)
elapsed = _time.monotonic() - t0
logger.info(
"%s %s %d %.3fs",
request.method,
request.url.path,
response.status_code,
elapsed,
)
return response
# --- Active LoRA state (shared across endpoints via closure) --- # --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None} active_lora: dict = {"request": None}
@@ -300,7 +317,11 @@ def main(script_args: ScriptArguments):
import vllm import vllm
from packaging.version import Version from packaging.version import Version
from vllm.sampling_params import GuidedDecodingParams
try:
from vllm.sampling_params import GuidedDecodingParams
except ImportError:
GuidedDecodingParams = None # not available in vLLM 0.17+
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item] images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
prompts: list[dict[str, Any]] = [] prompts: list[dict[str, Any]] = []
@@ -362,7 +383,12 @@ def main(script_args: ScriptArguments):
} }
conn.send({"type": "call", "method": "generate", "kwargs": kwargs}) conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections] # Use run_in_executor so blocking recv() doesn't freeze the event loop
# (allows /set_lora_adapter/ and other endpoints to be served concurrently)
loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
all_outputs = [ all_outputs = [
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
] ]
@@ -404,7 +430,10 @@ def main(script_args: ScriptArguments):
} }
conn.send({"type": "call", "method": "chat", "kwargs": kwargs}) conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections] loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs)) all_outputs = list(chain.from_iterable(all_outputs))
@@ -474,11 +503,51 @@ def main(script_args: ScriptArguments):
) )
return {"message": f"Batch update for {len(params_list)} params"} return {"message": f"Batch update for {len(params_list)} params"}
class HTTPWeightUpdateRequest(BaseModel):
"""Weight update via HTTP (no NCCL needed)."""
params: list[
dict
] # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]
@app.post("/http_update_weights/")
async def http_update_weights(request: HTTPWeightUpdateRequest):
"""Update model weights via HTTP — no NCCL communicator required.
Tensor data is sent as base64-encoded raw bytes in the request body.
Slower than NCCL for large models but works without cross-process setup.
"""
from axolotl.utils.weight_serde import (
decode_from_http,
encode_for_ipc,
)
weights_to_load = [decode_from_http(p) for p in request.params]
# Send all weights in a single IPC call. Tensors don't survive
# vLLM's multiproc IPC, so serialize as raw bytes + metadata.
param_entries = [
encode_for_ipc(name, weight) for name, weight in weights_to_load
]
kwargs = {
"method": "http_load_weights_batch",
"kwargs": {"params": param_entries},
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": f"HTTP weight update for {len(weights_to_load)} params"}
@app.post("/reset_prefix_cache/") @app.post("/reset_prefix_cache/")
async def reset_prefix_cache(): async def reset_prefix_cache():
for conn in connections: for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"}) conn.send({"type": "call", "method": "reset_prefix_cache"})
results = [conn.recv() for conn in connections] loop = asyncio.get_running_loop()
results = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
return {"message": f"Reset prefix cache: {all(results)}"} return {"message": f"Reset prefix cache: {all(results)}"}
@app.post("/close_communicator/") @app.post("/close_communicator/")
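
The `run_in_executor` pattern introduced in these hunks, which moves blocking `conn.recv()` calls off the event loop, can be tried in isolation. This is a minimal sketch using in-process `multiprocessing.Pipe` pairs instead of real vLLM worker connections; `gather_replies` and `demo` are hypothetical names.

```python
import asyncio
from multiprocessing import Pipe


async def gather_replies(connections):
    """Await blocking recv() calls in the default thread pool."""
    loop = asyncio.get_running_loop()
    return await asyncio.gather(
        *(loop.run_in_executor(None, conn.recv) for conn in connections)
    )


def demo():
    pairs = [Pipe() for _ in range(3)]
    # Pretend each worker has already replied on its end of the pipe
    for i, (_parent, child) in enumerate(pairs):
        child.send({"worker": i})
    return asyncio.run(gather_replies([parent for parent, _child in pairs]))


replies = demo()
```

`asyncio.gather` returns results in the order the awaitables were passed, so replies stay aligned with `connections`, which the later `zip(..., strict=True)` in the diff depends on.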


@@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
model = self.model_runner.model model = self.model_runner.model
params_dict = dict(model.named_parameters()) params_dict = dict(model.named_parameters())
# Handle VLM models where trainer and vLLM use different prefixes.
# Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..."
# vLLM (Qwen3.5): "language_model.model.layers.X..."
if name not in params_dict:
# Try common prefix remappings
for src_prefix, dst_prefix in [
("model.language_model.layers.", "language_model.model.layers."),
("model.layers.", "language_model.model.layers."),
]:
if name.startswith(src_prefix):
name = dst_prefix + name[len(src_prefix) :]
break
# Check if this is a simple direct param (exists as-is) # Check if this is a simple direct param (exists as-is)
if name in params_dict: if name in params_dict:
params_dict[name].data.copy_(weight.to(params_dict[name].dtype)) params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
@@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
return return
# Fallback: try load_weights (may work for non-stacked params) # Fallback: try load_weights (may work for non-stacked params)
logger.warning("Falling back to load_weights for param: %s", name) # Log the actual param names available for debugging
sample_keys = [
k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k
][:3]
logger.warning(
"Falling back to load_weights for param: %s (sample vLLM keys: %s)",
name,
sample_keys,
)
model.load_weights(weights=[(name, weight)]) model.load_weights(weights=[(name, weight)])
def update_named_param(self, name, dtype, shape): def update_named_param(self, name, dtype, shape):
@@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
# Load weights using direct set (handles stacked params) # Load weights using direct set (handles stacked params)
for name, weight in weights_to_load: for name, weight in weights_to_load:
self._direct_set_weight(name, weight) self._direct_set_weight(name, weight)
def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]):
"""Load weights received via HTTP (no NCCL needed)."""
for name, weight in weights:
self._direct_set_weight(name, weight.to(self.device))
def http_load_weight(self, **kwargs):
"""Load a single weight received via HTTP (no NCCL needed).
Reconstructs the tensor from raw bytes since tensors don't survive
vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights``
which handles TP sharding and stacked-param packing automatically.
"""
from axolotl.utils.weight_serde import decode_from_ipc
name, weight = decode_from_ipc(kwargs)
model = self.model_runner.model
model.load_weights(weights=[(name, weight)])
def http_load_weights_batch(self, params: list[dict]):
"""Load multiple weights in a single IPC call.
Uses vLLM's ``load_weights`` which handles TP sharding automatically.
"""
from axolotl.utils.weight_serde import decode_from_ipc
model = self.model_runner.model
weights = [decode_from_ipc(p) for p in params]
model.load_weights(weights=weights)
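
The prefix remapping added to `_direct_set_weight` can be sketched as a pure function. The prefix pairs are taken from the diff; `remap_param_name` and the sample parameter names are hypothetical and only illustrate the lookup order.

```python
PREFIX_REMAPS = [
    ("model.language_model.layers.", "language_model.model.layers."),
    ("model.layers.", "language_model.model.layers."),
]


def remap_param_name(name, known_names):
    """Return a name that may match vLLM's naming when the trainer's does not."""
    if name in known_names:
        return name  # already matches, no remapping needed
    for src_prefix, dst_prefix in PREFIX_REMAPS:
        if name.startswith(src_prefix):
            return dst_prefix + name[len(src_prefix):]
    return name  # unknown layout: leave unchanged for the fallback path


vllm_names = {"language_model.model.layers.0.mlp.up_proj.weight"}
remapped = remap_param_name("model.layers.0.mlp.up_proj.weight", vllm_names)
```

The first matching prefix wins, so the longer `model.language_model.layers.` pattern is listed before the shorter `model.layers.` one.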


@@ -138,7 +138,11 @@ def setup_reference_model(
model_ref = None # explicit setting to None model_ref = None # explicit setting to None
else: else:
reference_model: bool = True reference_model: bool = True
if cfg.rl == RLType.GRPO and cfg.trl.beta == 0: trl_cfg = getattr(cfg, "trl", None)
if (
cfg.rl in {RLType.GRPO, RLType.EBFT}
and getattr(trl_cfg, "beta", 0) == 0
):
reference_model = False reference_model = False
# load the model again for model_ref/baseline # load the model again for model_ref/baseline
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model) model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
@@ -206,7 +210,7 @@ def execute_training(
                 gradient_accumulation_steps=cfg.gradient_accumulation_steps,
                 ring_attn_func=cfg.ring_attn_func,
                 heads_k_stride=cfg.heads_k_stride,
-                gather_outputs=cfg.rl is RLType.GRPO,
+                gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT},
                 device_mesh=trainer.accelerator.torch_device_mesh,
             )
         )
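
The `setup_reference_model` hunk above sets `reference_model = False` for GRPO and EBFT runs whose `beta` is zero, and now tolerates a missing `trl` section. A minimal standalone sketch of that decision follows. The `RLType` enum here is a cut-down stand-in and `wants_reference_model` is a hypothetical helper name, not part of axolotl:

```python
from enum import Enum
from types import SimpleNamespace


class RLType(str, Enum):
    """Stand-in for axolotl's RLType; only the members needed here."""

    DPO = "dpo"
    GRPO = "grpo"
    EBFT = "ebft"


def wants_reference_model(cfg) -> bool:
    """Hypothetical helper mirroring the condition in the hunk above."""
    trl_cfg = getattr(cfg, "trl", None)
    # getattr on None falls back to the default, so a config without a
    # `trl` section is treated as beta == 0.
    if cfg.rl in {RLType.GRPO, RLType.EBFT} and getattr(trl_cfg, "beta", 0) == 0:
        return False
    return True


ebft_no_trl = SimpleNamespace(rl=RLType.EBFT, trl=None)
grpo_with_kl = SimpleNamespace(rl=RLType.GRPO, trl=SimpleNamespace(beta=0.04))
```

Note that the old condition, `cfg.trl.beta == 0`, would raise `AttributeError` when `cfg.trl` is `None`; the `getattr` form avoids that.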


@@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
                 ].append(pred_step_text)
                 row_index += 1
             if logger == "wandb":
-                # type: ignore[attr-defined]
-                wandb.run.log(
+                wandb.run.log(  # type: ignore[attr-defined]
                     {
                         f"{name} - Predictions vs Ground Truth": pd.DataFrame(
                             table_data
@@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                     mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                 ) as temp_file:
                     copyfile(self.axolotl_config_path, temp_file.name)
-                    artifact = wandb.Artifact(
-                        f"config-{wandb.run.id}", type="axolotl-config"
+                    artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                        f"config-{wandb.run.id}",  # type: ignore[attr-defined]
+                        type="axolotl-config",
                     )
                     artifact.add_file(temp_file.name)
-                    wandb.log_artifact(artifact)
-                    wandb.save(temp_file.name)
+                    wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                    wandb.save(temp_file.name)  # type: ignore[attr-defined]
                     LOG.info(
                         "The Axolotl config has been saved to the WandB run under files."
                     )
@@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                         temp_ct_file.write(str(chat_tpl))
                         temp_ct_file.flush()

-                        artifact = wandb.Artifact(
-                            f"chat-template-{wandb.run.id}", type="jinja-template"
+                        artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                            f"chat-template-{wandb.run.id}",  # type: ignore[attr-defined]
+                            type="jinja-template",
                         )
                         artifact.add_file(temp_ct_file.name)
-                        wandb.log_artifact(artifact)
-                        wandb.save(temp_ct_file.name)
+                        wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                        wandb.save(temp_ct_file.name)  # type: ignore[attr-defined]
                         LOG.info(
                             "The chat_template_jinja has been saved to the WandB run under files."
                         )
@@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                         else:
                             skip_upload = True
                     if not skip_upload:
-                        artifact = wandb.Artifact(
-                            f"deepspeed-config-{wandb.run.id}",
+                        artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                            f"deepspeed-config-{wandb.run.id}",  # type: ignore[attr-defined]
                             type="deepspeed-config",
                         )
                         artifact.add_file(temp_file.name)
-                        wandb.log_artifact(artifact)
-                        wandb.save(temp_file.name)
+                        wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                        wandb.save(temp_file.name)  # type: ignore[attr-defined]
                         LOG.info(
                             "The DeepSpeed config has been saved to the WandB run under files."
                         )


@@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback):
         if not getattr(cfg, "generate_samples", False):
             return

         dataloader = None
         try:
             if getattr(self.trainer, "eval_dataset", None) is not None:
                 dataloader = self.trainer.get_eval_dataloader()
                 LOG.info(
                     f"Using eval dataloader for generation at step {state.global_step}"
                 )
         except Exception as e:
             LOG.warning(f"Could not get eval dataloader: {e}")
             dataloader = None

         if dataloader is None:
             dataloader = self.trainer.get_train_dataloader()
             LOG.info(
                 f"Using train dataloader for generation at step {state.global_step}"
             )

         samples = generate_samples(
             model=self.trainer.model,
             tokenizer=self.trainer.processing_class,
             dataloader=dataloader,
             num_generation_samples=getattr(cfg, "num_generation_samples", 3),
             max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
             temperature=getattr(cfg, "generation_temperature", 0.7),
             top_p=getattr(cfg, "generation_top_p", None),
             top_k=getattr(cfg, "generation_top_k", None),
             do_sample=getattr(cfg, "generation_do_sample", True),
             prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
         )
         self._log_samples(samples, state.global_step)

     def _log_samples(self, samples: list, step: int):
         """Log generated samples to console and W&B."""
@@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback):
         try:
             import wandb

-            if wandb.run is not None:
-                wandb.log(
+            if wandb.run is not None:  # type: ignore[attr-defined]
+                wandb.log(  # type: ignore[attr-defined]
                     {
-                        f"samples/sample_{i + 1}": wandb.Html(
+                        f"samples/sample_{i + 1}": wandb.Html(  # type: ignore[attr-defined]
                             f"<pre>{wandb_text}</pre>"
                         )
                     },


@@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer

 from axolotl.loaders import load_tokenizer
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.ebft import load as load_ebft
 from axolotl.prompt_strategies.kto import load as load_kto
 from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.lock import FileLockLoader
@@ -173,7 +174,7 @@ def _drop_long_sequences(
         return (len_prompt + len_completion) <= sequence_len

-    if rl in {RLType.GRPO, RLType.GDPO}:
+    if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}:
         return True

     raise ValueError("Unknown RL type")
@@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
                 ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
             elif cfg.rl is RLType.KTO:
                 ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
+            elif cfg.rl is RLType.EBFT:
+                ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i)
             else:
                 ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)

             map_kwargs: dict[str, Any] = {}
             if isinstance(ds_transform_fn, tuple):
                 ds_transform_fn, map_kwargs = ds_transform_fn
+            # Handle remove_columns: "__all__" removes all original columns,
+            # or filter a list to only columns that exist in the dataset
+            if "remove_columns" in map_kwargs:
+                ds_columns = (
+                    dataset.column_names
+                    if isinstance(dataset, Dataset)
+                    else dataset[split].column_names
+                    if isinstance(dataset, DatasetDict)
+                    else []
+                )
+                if map_kwargs["remove_columns"] == "__all__":
+                    map_kwargs["remove_columns"] = list(ds_columns)
+                else:
+                    map_kwargs["remove_columns"] = [
+                        c for c in map_kwargs["remove_columns"] if c in ds_columns
+                    ]
             split_datasets[i] = _map_dataset(
                 cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
             )
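
The `remove_columns` handling added to `_load_split` can be read in isolation: the sentinel `"__all__"` expands to every column the dataset has, and an explicit list is filtered down to columns that actually exist. A small sketch of that rule, with `resolve_remove_columns` as a hypothetical helper name introduced only for illustration:

```python
def resolve_remove_columns(requested, ds_columns):
    """Hypothetical helper mirroring the remove_columns logic in the hunk above.

    requested  -- either the sentinel string "__all__" or a list of column names
    ds_columns -- the column names present in the dataset split
    """
    if requested == "__all__":
        # Sentinel: drop every original column.
        return list(ds_columns)
    # Explicit list: silently skip names the dataset does not have.
    return [c for c in requested if c in ds_columns]


columns = ["prompt", "completion", "metadata"]
drop_everything = resolve_remove_columns("__all__", columns)
drop_some = resolve_remove_columns(["metadata", "not_a_column"], columns)
```

Filtering rather than raising means one transform can be reused across datasets whose schemas differ slightly.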


@@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__)
class EBFTConfig(BaseModel):
"""Configuration for Energy-Based Fine-Tuning (EBFT)"""
feature_layers: list[float] = Field(
default=[0.25, 0.5, 0.75],
json_schema_extra={
"description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])"
},
)
embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = (
Field(
default="last_token",
json_schema_extra={
"description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"
},
)
)
use_whitening: bool = Field(
default=False,
json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
)
alignment_coef: float = Field(
default=1.0,
json_schema_extra={
"description": "Coefficient for alignment reward (cosine similarity with ground truth)"
},
)
diversity_coef: float = Field(
default=1.0,
json_schema_extra={
"description": "Coefficient for diversity penalty (pairwise similarity between samples)"
},
)
ce_coef: float = Field(
default=0.0,
json_schema_extra={
"description": "Cross-entropy loss coefficient on ground-truth tokens"
},
)
adaptive_max_tokens: bool = Field(
default=True,
json_schema_extra={
"description": "Set per-batch max_tokens based on ground-truth length"
},
)
gt_length_multiplier: float = Field(
default=1.5,
ge=0.1,
json_schema_extra={
"description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
},
)
# Strided mode fields (for unstructured text)
mode: Literal["structured", "strided"] = Field(
default="structured",
json_schema_extra={
"description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
},
)
stride: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Stride between anchor points (tokens)"},
)
context_length: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Context window size per block"},
)
generate_max_len: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Tokens to generate per block"},
)
n_samples_per_prompt: int = Field(
default=4,
ge=1,
json_schema_extra={"description": "Independent rollouts per document"},
)
temperature: float = Field(
default=0.6,
ge=0.0,
json_schema_extra={
"description": "Sampling temperature for strided generation"
},
)
top_p: float = Field(
default=1.0,
ge=0.0,
le=1.0,
json_schema_extra={"description": "Top-p nucleus sampling threshold"},
)
rl_coef: float = Field(
default=1.0,
json_schema_extra={"description": "RL policy gradient loss coefficient"},
)
advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
default="rloo",
json_schema_extra={
"description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
},
)
min_completion_prefix: int = Field(
default=0,
ge=0,
json_schema_extra={
"description": "Minimum tokens into completion before placing anchors. "
"Skips anchors too close to the prompt boundary where features are dominated by prompt context."
},
)
class AxolotlInputConfig(
ModelInputConfig,
ModelOutputConfig,
@@ -131,7 +244,7 @@ class AxolotlInputConfig(
     rl: RLType | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
+            "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'"
         },
     )
     trl: TRLConfig | None = Field(
@@ -140,6 +253,12 @@ class AxolotlInputConfig(
     vllm: VllmConfig | None = Field(
         default_factory=lambda: VllmConfig(),
     )
+    ebft: EBFTConfig | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
+        },
+    )
     qat: QATConfig | None = None
     quantization: PTQConfig | None = None
     reward_model: bool | None = Field(
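
To make the `EBFTConfig` bounds concrete, here is a strided-mode config written as the Python dict a YAML file would parse to, plus a plain-Python check of the `ge`/`le` constraints declared above. Field names, defaults, and bounds are taken from `EBFTConfig`; the `check_ebft_bounds` helper is hypothetical and only approximates what pydantic enforces in the real schema:

```python
config = {
    "rl": "ebft",
    "ebft": {
        "mode": "strided",
        "feature_layers": [0.25, 0.5, 0.75],
        "embed_method": "last_token",
        "stride": 8,
        "context_length": 8,
        "generate_max_len": 8,
        "n_samples_per_prompt": 4,
        "temperature": 0.6,
        "top_p": 1.0,
        "advantage_estimator": "rloo",
    },
}


def check_ebft_bounds(ebft: dict) -> list[str]:
    """Hypothetical stand-in for the numeric constraints on EBFTConfig."""
    errors = []
    # Fields declared with ge=1 in the schema.
    for key in ("stride", "context_length", "generate_max_len", "n_samples_per_prompt"):
        if ebft.get(key, 1) < 1:
            errors.append(f"{key} must be >= 1")
    if ebft.get("temperature", 0.6) < 0.0:
        errors.append("temperature must be >= 0.0")
    if not 0.0 <= ebft.get("top_p", 1.0) <= 1.0:
        errors.append("top_p must be in [0.0, 1.0]")
    if ebft.get("gt_length_multiplier", 1.5) < 0.1:
        errors.append("gt_length_multiplier must be >= 0.1")
    if ebft.get("min_completion_prefix", 0) < 0:
        errors.append("min_completion_prefix must be >= 0")
    return errors


problems = check_ebft_bounds(config["ebft"])
```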


@@ -35,6 +35,7 @@ class RLType(str, Enum):
     ORPO = "orpo"
     KTO = "kto"
     SIMPO = "simpo"
+    EBFT = "ebft"


 class ChatTemplate(str, Enum):


@@ -1,6 +1,6 @@
 """Pydantic models for TRL trainer configuration"""

-from typing import Literal
+from typing import Any, Literal

 from pydantic import BaseModel, Field
@@ -133,6 +133,20 @@ class TRLConfig(BaseModel):
             "description": "Penalty for tokens that appear in prompt and generated text."
         },
     )
+    generation_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Additional generation parameters passed to vLLM SamplingParams. "
+            "Useful for stop_token_ids, seed, frequency_penalty, etc."
+        },
+    )
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Additional kwargs for the chat template. "
+            "E.g., {enable_thinking: false} for Qwen3.5 models."
+        },
+    )
     num_iterations: int | None = Field(
         default=None,
         json_schema_extra={

View File

@@ -1482,6 +1482,124 @@ class DistributedValidationMixin:
return self
class EBFTValidationMixin:
"""Validation for EBFT (Energy-Based Fine-Tuning) configuration."""
@model_validator(mode="before")
@classmethod
def check_ebft_config_required(cls, data):
"""rl: ebft requires an ebft config section."""
if data.get("rl") == "ebft" and not data.get("ebft"):
raise ValueError(
"`ebft` config section is required when `rl: ebft` is set. "
"Add an `ebft:` section with at least `mode: structured` or `mode: strided`."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_torch_compile(cls, data):
"""torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles
and CheckpointErrors. The flex_attention kernel compiles itself internally —
whole-model torch.compile is not needed and actively harmful."""
if (
data.get("rl") == "ebft"
and data.get("torch_compile") is True
and data.get("ebft", {}).get("mode") == "strided"
):
if data.get("gradient_checkpointing"):
raise ValueError(
"EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` "
"causes CheckpointError (BlockMask metadata mismatch during recomputation). "
"Remove `torch_compile` — the flex_attention kernel compiles itself internally."
)
LOG.warning(
"EBFT strided mode: `torch_compile: true` causes dynamo recompiles from "
"variable sequence lengths across steps. Consider removing it — "
"flex_attention compiles itself internally."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_gradient_checkpointing_reentrant(cls, data):
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("flex_attention")
and data.get("gradient_checkpointing")
):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
if not gc_kwargs.get("use_reentrant"):
LOG.warning(
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
)
if data.get("gradient_checkpointing_kwargs") is None:
data["gradient_checkpointing_kwargs"] = {}
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
return data
@model_validator(mode="before")
@classmethod
def check_ebft_activation_offloading(cls, data):
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
which conflicts with flex_attention's use_reentrant requirement."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("activation_offloading") is True
and data.get("flex_attention")
):
raise ValueError(
"EBFT strided mode: `activation_offloading: true` is incompatible with "
"`flex_attention: true`. Activation offloading replaces gradient checkpointing "
"with FSDP-style wrapping that conflicts with flex_attention's reentrant "
"checkpoint requirement. Remove `activation_offloading` — the strided trainer "
"uses micro-batched forward passes for memory efficiency instead."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_strided_sequence_len(cls, data):
"""Warn if sequence_len is too large for single-GPU strided EBFT."""
if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided":
return data
ebft = data.get("ebft", {})
seq_len = data.get("sequence_len", 512)
n_samples = ebft.get("n_samples_per_prompt", 4)
gen_len = ebft.get("generate_max_len", 8)
stride = ebft.get("stride", 8)
ctx_len = ebft.get("context_length", 8)
max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
full_seq = seq_len + max_blocks * gen_len
# Rough estimate: 8.7 GB per sample at S=3900 for 1B model
if full_seq * n_samples > 20000:
LOG.warning(
f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = "
f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM "
f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_strided_dataset_split(cls, data):
"""Warn about the common `train_on_split` mistake (silently ignored by schema)."""
datasets = data.get("datasets", [])
for ds in datasets or []:
if isinstance(ds, dict) and ds.get("train_on_split"):
LOG.warning(
f"Dataset has `train_on_split: {ds['train_on_split']}` — this field "
f"is not recognized and will be silently ignored. "
f"Use `split: {ds['train_on_split']}` instead."
)
return data
class GRPOVllmValidationMixin:
"""Validation mixin for vllm when using GRPO."""
@@ -1507,6 +1625,7 @@ class ValidationMixin(
     PretrainingValidationMixin,
     ModelCompatibilityValidationMixin,
     ComplexValidationMixin,
+    EBFTValidationMixin,
     GRPOVllmValidationMixin,
 ):
     """Full validation mixin for Axolotl configuration."""

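The arithmetic inside `check_ebft_strided_sequence_len` is easier to follow with numbers plugged in. The sketch below copies the two formulas and the 20000 threshold from the validator; the function name `strided_token_budget` is hypothetical, and the defaults are the ones the validator falls back to:

```python
def strided_token_budget(seq_len=512, n_samples=4, gen_len=8, stride=8, ctx_len=8):
    """Hypothetical helper reproducing the validator's estimate.

    Returns (max_blocks, full_seq, token_samples) where token_samples is the
    quantity compared against the warning threshold.
    """
    # Number of anchor blocks that fit in the sequence.
    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    # Original sequence plus the tokens generated at every block.
    full_seq = seq_len + max_blocks * gen_len
    return max_blocks, full_seq, full_seq * n_samples


WARN_THRESHOLD = 20000  # value used in the validator above

default_budget = strided_token_budget()
long_budget = strided_token_budget(seq_len=4096)
```

With the defaults, 63 blocks give a full length of 1016 tokens and 4064 token-samples per step, well under the threshold. At `sequence_len: 4096` the same settings give 511 blocks, a full length of 8184, and 32736 token-samples, which would trigger the warning.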

@@ -57,6 +57,13 @@ class VllmConfig(BaseModel):
         default=None,
         json_schema_extra={"description": "Reasoning parser for VLLM"},
     )
+    enforce_eager: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Disable CUDA graph capture in vLLM. Required for models with "
+            "causal_conv1d (e.g., Qwen3.5 hybrid linear attention)."
+        },
+    )
     serve_module: str | None = Field(
         default=None,
         json_schema_extra={

View File

@@ -0,0 +1,94 @@
"""Serialize / deserialize tensors for HTTP and IPC weight sync.
NumPy doesn't support bfloat16, so bf16 tensors are cast to fp16 on the wire
and reconstructed at the destination. All encode/decode helpers live here so
the logic isn't duplicated across trl_vllm.py, vllm_serve_lora.py, and
vllm_worker_ext.py.
"""
import base64
import torch
def encode_for_http(name: str, weight: torch.Tensor) -> dict:
"""Encode a named parameter for JSON transport over HTTP.
Returns a dict with keys: name, dtype (original), shape, data (base64).
bf16 tensors are sent as fp16 bytes; the original dtype is preserved in
the ``dtype`` field so the receiver can cast back.
"""
w_cpu = weight.contiguous().cpu()
orig_dtype = str(weight.dtype)
if w_cpu.dtype == torch.bfloat16:
w_cpu = w_cpu.half()
raw = w_cpu.numpy().tobytes()
return {
"name": name,
"dtype": orig_dtype,
"shape": list(weight.shape),
"data": base64.b64encode(raw).decode("ascii"),
}
def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]:
"""Decode an HTTP-encoded weight entry back to a named tensor.
Infers wire dtype from byte count (bf16 arrives as fp16) and casts to the
original dtype stored in ``entry["dtype"]``.
"""
target_dtype = getattr(torch, entry["dtype"].split(".")[-1])
shape = tuple(entry["shape"])
raw = base64.b64decode(entry["data"])
n_elements = 1
for s in shape:
n_elements *= s
wire_bytes_per_elem = len(raw) // max(n_elements, 1)
if wire_bytes_per_elem == 2:
wire_dtype = torch.float16
elif wire_bytes_per_elem == 4:
wire_dtype = torch.float32
else:
wire_dtype = target_dtype
weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape)
if wire_dtype != target_dtype:
weight = weight.to(target_dtype)
return entry["name"], weight
def encode_for_ipc(name: str, weight: torch.Tensor) -> dict:
"""Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64).
Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype
(original), shape. bf16 tensors are serialized as fp16.
"""
w = weight.contiguous().cpu()  # .numpy() below requires a CPU tensor
target_dtype = str(w.dtype).split(".")[-1]
if w.dtype == torch.bfloat16:
w = w.half()
wire_dtype = str(w.dtype).split(".")[-1]
return {
"name": name,
"data": w.numpy().tobytes(),
"dtype": wire_dtype,
"target_dtype": target_dtype,
"shape": list(weight.shape),
}
def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]:
"""Decode an IPC-encoded weight entry back to a named tensor.
Handles optional ``target_dtype`` for backward compatibility with older
serve code that may not include it.
"""
wire_dtype = getattr(torch, entry["dtype"])
weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape(
entry["shape"]
)
target_dtype = entry.get("target_dtype")
if target_dtype and target_dtype != entry["dtype"]:
weight = weight.to(getattr(torch, target_dtype))
return entry["name"], weight
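
`decode_from_http` does not carry the wire dtype in the payload; it infers it from the byte count per element. The standard-library sketch below isolates that inference, using `struct` to build fp16 and fp32 payloads in place of torch tensors. `infer_wire_dtype` is a hypothetical name and returns dtype names as strings rather than `torch.dtype` objects:

```python
import base64
import struct
from math import prod


def infer_wire_dtype(shape, data_b64, target_dtype):
    """Hypothetical helper mirroring the byte-count inference in decode_from_http."""
    raw = base64.b64decode(data_b64)
    n_elements = prod(shape)  # prod of an empty shape is 1 (scalar tensor)
    bytes_per_elem = len(raw) // max(n_elements, 1)
    if bytes_per_elem == 2:
        return "float16"  # covers bf16 sources, which are sent as fp16
    if bytes_per_elem == 4:
        return "float32"
    return target_dtype  # anything else is assumed to be sent unconverted


# "e" is IEEE half precision and "f" is single precision in struct's format codes.
fp16_payload = base64.b64encode(struct.pack("<6e", *([0.5] * 6))).decode("ascii")
fp32_payload = base64.b64encode(struct.pack("<6f", *([0.5] * 6))).decode("ascii")
fp64_payload = base64.b64encode(struct.pack("<6d", *([0.5] * 6))).decode("ascii")
```

A consequence of this scheme is that any 2-byte payload is read as fp16, so the inference is only safe for the floating-point weights this module is meant to carry.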

294
tests/test_ebft_kernels.py Normal file

@@ -0,0 +1,294 @@
"""Correctness tests for fused EBFT Triton kernels."""
import pytest
import torch
import torch.nn.functional as F
from axolotl.core.trainers.ebft.kernels import (
fused_cosine_similarity,
fused_diversity_penalty,
fused_log_softmax_gather,
fused_reinforce_loss,
)
# Skip all tests if CUDA not available
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
DEVICE = "cuda"
# ---------------------------------------------------------------------------
# 1. fused_log_softmax_gather
# ---------------------------------------------------------------------------
class TestFusedLogSoftmaxGather:
def _reference(self, logits, labels):
"""PyTorch reference: log_softmax + gather."""
lp = F.log_softmax(logits.float(), dim=-1)
return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
def test_basic_correctness(self):
B, S, V = 2, 16, 1024
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_large_vocab(self):
"""Test with realistic vocab size (128K)."""
B, S, V = 1, 8, 128256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_fp32_input(self):
B, S, V = 2, 8, 512
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
def test_output_is_negative(self):
"""log_softmax values should always be <= 0."""
B, S, V = 4, 32, 2048
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
out = fused_log_softmax_gather(logits, labels)
assert (out <= 1e-5).all(), "log_softmax values should be <= 0"
def test_extreme_logits(self):
"""Test numerical stability with very large/small logits."""
B, S, V = 1, 4, 256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
logits[:, 0, :] = 1000.0 # very large
logits[:, 1, :] = -1000.0 # very small
logits[:, 2, 0] = 1000.0 # one hot-ish
labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
assert torch.isfinite(out).all(), "Should handle extreme values"
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with pre-flattened (N, V) input."""
N, V = 64, 4096
logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (N,), device=DEVICE)
ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 2. fused_reinforce_loss
# ---------------------------------------------------------------------------
class TestFusedReinforceLoss:
def _reference(self, logps, advantages, mask):
"""PyTorch reference implementation."""
loss_per_token = -logps * advantages
return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)
def test_basic_correctness(self):
N = 1024
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with (B, S) shaped inputs."""
B, S = 4, 256
logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_all_masked(self):
"""All-zero mask should return 0."""
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
out = fused_reinforce_loss(logps, advs, mask)
assert out.item() == 0.0
def test_all_unmasked(self):
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic size (4 * 3000 tokens)."""
N = 12000
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 3. fused_cosine_similarity
# ---------------------------------------------------------------------------
class TestFusedCosineSimilarity:
def test_basic_correctness(self):
N, D = 64, 256
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_batched(self):
"""Test with (B, N, NB, D) shaped input."""
B, N, NB, D = 2, 4, 16, 512
a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_identical_vectors(self):
"""Identical vectors should give similarity = 1."""
N, D = 16, 128
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, a)
torch.testing.assert_close(
out,
torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_orthogonal_vectors(self):
"""Orthogonal vectors should give similarity = 0."""
D = 128
a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
a[0, 0] = 1.0
b[0, 1] = 1.0
out = fused_cosine_similarity(a, b)
assert abs(out.item()) < 1e-5
def test_opposite_vectors(self):
"""Opposite vectors should give similarity = -1."""
N, D = 8, 64
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, -a)
torch.testing.assert_close(
out,
-torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_large_dimension(self):
"""Test with large feature dimension (multi-layer concatenated features)."""
N, D = 32, 4608 # 3 layers * 1536 hidden
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
# ---------------------------------------------------------------------------
# 4. fused_diversity_penalty
# ---------------------------------------------------------------------------
class TestFusedDiversityPenalty:
def _reference(self, embeddings):
"""PyTorch reference: bmm + mask diagonal + mean."""
B, N, D = embeddings.shape
sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
return sims.sum(dim=-1) / (N - 1)
def test_basic_correctness(self):
B, N, D = 4, 4, 256
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_two_samples(self):
"""With n=2, diversity = dot(a, b) for each."""
B, D = 3, 128
emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_identical_samples(self):
"""All identical samples should give the maximum penalty (i.e., the least diversity)."""
B, N, D = 2, 4, 64
vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
emb = vec.expand(B, N, D).contiguous()
out = fused_diversity_penalty(emb)
# Should be ||vec||^2 for each (self-excluded mean of identical dot products)
expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic EBFT dimensions."""
B, N, D = 1, 4, 4608 # multi-layer features
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)
def test_single_sample_returns_zeros(self):
"""N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
B, D = 3, 128
emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
out = fused_diversity_penalty(emb)
assert (out == 0).all(), "N=1 diversity should be exactly zero"
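The quantity these tests check can be sketched without PyTorch. The block below is a minimal pure-Python illustration of a self-excluded mean pairwise dot product for one batch element, assuming the same definition as `_reference`. The name `diversity_penalty` is hypothetical and this is not the fused kernel itself; returning zeros for a single sample mirrors what `test_single_sample_returns_zeros` expects.

```python
def diversity_penalty(embeddings):
    """Mean dot product of each sample with every *other* sample.

    embeddings: list of N vectors (lists of floats) for one batch element.
    Returns a list of N floats. With fewer than two samples there are no
    pairs, so the result is all zeros.
    """
    n = len(embeddings)
    if n < 2:
        return [0.0] * n
    penalties = []
    for i, a in enumerate(embeddings):
        total = 0.0
        for j, b in enumerate(embeddings):
            if i == j:
                continue  # skip self-similarity (the masked diagonal)
            total += sum(x * y for x, y in zip(a, b))
        penalties.append(total / (n - 1))
    return penalties


# Two samples: each penalty is just dot(a, b).
two = diversity_penalty([[1.0, 2.0], [3.0, 4.0]])

# Identical samples: each penalty equals the squared norm of the vector.
same = diversity_penalty([[1.0, 2.0]] * 4)
```

With identical samples every off-diagonal dot product equals the squared norm, so averaging over the other N-1 samples leaves that value unchanged, which is the expectation encoded in `test_identical_samples`.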

View File

@@ -0,0 +1,363 @@
"""Tests for the EBFT strided structured dataset transform and data loading."""

import pytest
from datasets import Dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast

from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault


@pytest.fixture
def tokenizer():
    """Create a simple word-level tokenizer; no network access needed."""
    # Build a tiny vocab covering common test words
    vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
    words = (
        "what is 2 + the answer 4 hello world goodbye bye hi short prompt "
        "x write code print test some string metadata noise ok good python "
        "sampling abc 123 def solve return this that"
    ).split()
    for w in words:
        if w not in vocab:
            vocab[w] = len(vocab)
    backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
    backend.pre_tokenizer = pre_tokenizers.Whitespace()
    tok = PreTrainedTokenizerFast(
        tokenizer_object=backend,
        bos_token="[BOS]",
        eos_token="[EOS]",
        pad_token="[PAD]",
        unk_token="[UNK]",
    )
    return tok


@pytest.fixture
def cfg():
    return DictDefault({"sequence_len": 64})


@pytest.fixture
def transform_fn_and_kwargs(cfg):
    result = load_ebft("ebft_strided_structured.transform", cfg)
    assert result is not None, "Failed to load ebft_strided_structured transform"
    transform_fn, map_kwargs = result
    return transform_fn, map_kwargs


class TestEBFTStridedStructuredTransform:
    """Tests for the dataset transform function itself."""

    def test_transform_loads(self, transform_fn_and_kwargs):
        transform_fn, map_kwargs = transform_fn_and_kwargs
        assert callable(transform_fn)
        assert "remove_columns" in map_kwargs

    def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
        """Transform should signal removal of all original columns."""
        _, map_kwargs = transform_fn_and_kwargs
        assert map_kwargs["remove_columns"] == "__all__"

    def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
        """Prompt tokens get labels=-100, completion tokens get real labels."""
        transform_fn, _ = transform_fn_and_kwargs
        example = {"input": "what is 2 + 2", "output": "the answer is 4"}
        result = transform_fn(example, tokenizer=tokenizer)
        assert "input_ids" in result
        assert "labels" in result
        assert "attention_mask" in result
        assert "prompt_length" in result
        prompt_length = result["prompt_length"]
        labels = result["labels"]
        seq_len = len(result["input_ids"])
        assert seq_len == 64, "Should be padded to sequence_len"
        assert len(labels) == seq_len
        assert prompt_length > 0
        # Prompt tokens should be masked
        for i in range(prompt_length):
            assert labels[i] == -100, f"Prompt token at {i} should be -100"
        # At least one completion token should have a real label
        completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
        assert len(completion_labels) > 0, "Should have non-masked completion tokens"

    def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
        """prompt_length should be the exact boundary between -100 and real labels."""
        transform_fn, _ = transform_fn_and_kwargs
        example = {"input": "hello world", "output": "goodbye world"}
        result = transform_fn(example, tokenizer=tokenizer)
        prompt_length = result["prompt_length"]
        labels = result["labels"]
        assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
        assert labels[prompt_length] != -100, (
            "First completion token should not be masked"
        )

    def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
        """Padding tokens should have labels=-100 and attention_mask=0."""
        transform_fn, _ = transform_fn_and_kwargs
        example = {"input": "hi", "output": "bye"}
        result = transform_fn(example, tokenizer=tokenizer)
        attention_mask = result["attention_mask"]
        labels = result["labels"]
        real_len = sum(attention_mask)
        assert real_len < 64, "Short example should have padding"
        for i in range(real_len, 64):
            assert attention_mask[i] == 0, (
                f"Pad position {i} should have attention_mask=0"
            )
            assert labels[i] == -100, f"Pad position {i} should have labels=-100"

    def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
        """Long completions should be truncated to fit sequence_len."""
        transform_fn, _ = transform_fn_and_kwargs
        example = {
            "input": "short prompt",
            "output": "x " * 500,
        }
        result = transform_fn(example, tokenizer=tokenizer)
        assert len(result["input_ids"]) == 64

    def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
        """Transform should handle different field name conventions."""
        transform_fn, _ = transform_fn_and_kwargs
        result = transform_fn(
            {"prompt": "what", "completion": "this"}, tokenizer=tokenizer
        )
        assert result["prompt_length"] > 0
        result = transform_fn(
            {"question": "what", "answer": "this"}, tokenizer=tokenizer
        )
        assert result["prompt_length"] > 0

    def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
        """Without tokenizer, should return a prompt string."""
        transform_fn, _ = transform_fn_and_kwargs
        result = transform_fn({"input": "hello", "output": "world"})
        assert "prompt" in result
        assert result["prompt"] == "hello"


class TestEBFTColumnRemoval:
    """Tests for the __all__ column removal logic in the RL data path."""

    def _filter_remove_columns(self, map_kwargs, dataset):
        """Reproduce the filtering logic from rl.py _load_split."""
        if "remove_columns" in map_kwargs:
            ds_columns = dataset.column_names
            if map_kwargs["remove_columns"] == "__all__":
                map_kwargs["remove_columns"] = list(ds_columns)
            else:
                map_kwargs["remove_columns"] = [
                    c for c in map_kwargs["remove_columns"] if c in ds_columns
                ]
        return map_kwargs

    def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
        """After mapping, only tokenized columns should remain."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)  # copy
        ds = Dataset.from_list(
            [
                {"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
            ]
        )
        map_kwargs = self._filter_remove_columns(map_kwargs, ds)
        assert "input" in map_kwargs["remove_columns"]
        assert "output" in map_kwargs["remove_columns"]
        assert "extra_field" in map_kwargs["remove_columns"]
        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
        assert "input_ids" in mapped.column_names
        assert "labels" in mapped.column_names
        assert "prompt_length" in mapped.column_names
        assert "input" not in mapped.column_names
        assert "output" not in mapped.column_names
        assert "extra_field" not in mapped.column_names

    def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
        """Datasets with many extra metadata columns should all be cleaned up."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)
        ds = Dataset.from_list(
            [
                {
                    "input": "write hello world",
                    "output": "print hello",
                    "id": "abc 123",
                    "domain": "python",
                    "generation_algorithm": "sampling",
                    "llm_judgement": "good",
                    "unit_tests": "test",
                    "tests_execution_status": "ok",
                    "average_test_score": 0.95,
                },
            ]
        )
        map_kwargs = self._filter_remove_columns(map_kwargs, ds)
        assert len(map_kwargs["remove_columns"]) == 9
        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
        expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
        assert set(mapped.column_names) == expected_columns

    def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
        """No string-typed columns should remain (would crash the DataLoader)."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)
        ds = Dataset.from_list(
            [
                {"input": "test", "output": "test", "notes": "some string metadata"},
            ]
        )
        map_kwargs = self._filter_remove_columns(map_kwargs, ds)
        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
        for col in mapped.column_names:
            feat = mapped.features[col]
            assert str(feat) != "string", (
                f"Column '{col}' is still a string and would crash the DataLoader"
            )

    def test_filter_preserves_explicit_list(self):
        """When remove_columns is an explicit list, only existing columns are kept."""
        ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
        map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
        ds_columns = ds.column_names
        map_kwargs["remove_columns"] = [
            c for c in map_kwargs["remove_columns"] if c in ds_columns
        ]
        assert map_kwargs["remove_columns"] == ["a", "b"]
        assert "missing_col" not in map_kwargs["remove_columns"]


class TestMultiTurnSeparators:
    """Verify multi-turn transforms and trainer-side GT reconstruction."""

    def test_multiturn_transform_splits_turns(self):
        """Transform should store first turn as GT and remaining turns separately."""
        from axolotl.prompt_strategies.ebft import load as load_ebft
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"sequence_len": 512})
        fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
        out = fn(
            {
                "messages": [
                    {"role": "user", "content": "Q1"},
                    {"role": "assistant", "content": "A1"},
                    {"role": "user", "content": "Q2"},
                    {"role": "assistant", "content": "A2"},
                ]
            }
        )
        # ground_truth is only the first assistant turn
        assert out["ground_truth"] == "A1"
        # remaining_turns carries the rest for trainer-side reconstruction
        assert out["remaining_turns"] == [
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]

    def test_multiturn_gt_reconstruction_via_chat_template(self):
        """Trainer-side GT reconstruction should insert role markers between turns.

        This tests the logic from trainer.py:284-299 that reconstructs multi-turn
        GT using apply_chat_template, ensuring assistant turns are separated by
        role markers rather than concatenated as raw text.
        """
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )
        # Simulate the transform output
        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        remaining_turns = [
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]
        # --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )
        gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
        gt_conv.extend(remaining_turns)
        full_gt_text = tokenizer.apply_chat_template(
            gt_conv, tokenize=False, add_generation_prompt=False
        )
        # The full GT text should contain both assistant turns with role markers
        assert "A1" in full_gt_text
        assert "A2" in full_gt_text
        # Raw concatenation "A1A2" should NOT appear: role markers separate the turns
        assert "A1A2" not in full_gt_text, (
            "GT reconstruction should have role markers between turns, not raw concatenation"
        )
        # The user turn Q2 should appear between A1 and A2
        a1_pos = full_gt_text.index("A1")
        a2_pos = full_gt_text.index("A2")
        q2_pos = full_gt_text.index("Q2")
        assert a1_pos < q2_pos < a2_pos, (
            "Turn order should be A1 -> Q2 -> A2 in rendered GT"
        )
        # The GT should start with the prompt
        assert full_gt_text.startswith(prompt_text), (
            "Full GT should start with the rendered prompt"
        )

    def test_multiturn_gt_reconstruction_fallback_single_turn(self):
        """Single-turn prompts in a multi-turn dataset should use raw concatenation."""
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )
        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        # remaining_turns would be [] for single-turn prompts
        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )
        # With empty remaining_turns, trainer falls through to raw concat
        # (trainer.py:302: gt_texts.append(prompt_text + gt))
        gt_text = prompt_text + gt
        assert gt_text.endswith("A1")
        assert prompt_text in gt_text
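The layout that the transform tests assert (prompt labels masked with -100, real labels on the completion, right padding with `attention_mask=0`) can be illustrated on plain token-id lists. The sketch below is an assumption-laden stand-in, not the axolotl transform: `build_example`, `IGNORE_INDEX`, and `pad_id` are hypothetical names, and the real code may add special tokens or truncate differently.

```python
IGNORE_INDEX = -100  # label value ignored by the loss


def build_example(prompt_ids, completion_ids, sequence_len, pad_id=0):
    """Concatenate, truncate, mask the prompt, and right-pad to sequence_len."""
    prompt_ids = prompt_ids[:sequence_len]
    input_ids = (prompt_ids + completion_ids)[:sequence_len]
    prompt_length = len(prompt_ids)
    real_len = len(input_ids)

    # Prompt positions are masked; completion positions keep their token ids.
    labels = [IGNORE_INDEX] * prompt_length + input_ids[prompt_length:]
    attention_mask = [1] * real_len

    pad = sequence_len - real_len
    return {
        "input_ids": input_ids + [pad_id] * pad,
        "attention_mask": attention_mask + [0] * pad,
        "labels": labels + [IGNORE_INDEX] * pad,
        "prompt_length": prompt_length,
    }


example = build_example([5, 6, 7], [8, 9], sequence_len=8)
```

Here `prompt_length` is the index of the first unmasked label, which is the boundary property checked by `test_prompt_length_matches_boundary`.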

View File

@@ -0,0 +1,158 @@
"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).

Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
"""

import pytest
import torch

from axolotl.utils.weight_serde import (
    decode_from_http,
    decode_from_ipc,
    encode_for_http,
    encode_for_ipc,
)


# ---------------------------------------------------------------------------
# Stage 1: trainer → serve endpoint (HTTP with base64)
# ---------------------------------------------------------------------------
class TestHttpEncodeRoundTrip:
    """Test encode_for_http / decode_from_http."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_http("layer.weight", original)
        name, decoded = decode_from_http(entry)
        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            # bf16→fp16→bf16 loses some precision
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_wire_format_is_fp16(self):
        """bf16 tensors should be sent as fp16 on the wire."""
        import base64

        original = torch.randn(8, 16, dtype=torch.bfloat16)
        entry = encode_for_http("w", original)
        raw = base64.b64decode(entry["data"])
        # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
        assert len(raw) == 8 * 16 * 2
        # dtype field should preserve original dtype for reconstruction
        assert entry["dtype"] == "torch.bfloat16"

    def test_multidimensional_shapes(self):
        for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
            original = torch.randn(*shape, dtype=torch.bfloat16)
            entry = encode_for_http("w", original)
            _, decoded = decode_from_http(entry)
            assert decoded.shape == original.shape
            assert decoded.dtype == torch.bfloat16


# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------
class TestIpcEncodeRoundTrip:
    """Test encode_for_ipc / decode_from_ipc."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_ipc("layer.weight", original)
        name, decoded = decode_from_ipc(entry)
        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_ipc_wire_is_fp16(self):
        """bf16 tensors should be serialized as fp16 bytes in IPC."""
        original = torch.randn(4, 8, dtype=torch.bfloat16)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float16"
        assert entry["target_dtype"] == "bfloat16"
        assert len(entry["data"]) == 4 * 8 * 2  # fp16 bytes

    def test_fp32_has_no_target_dtype_mismatch(self):
        original = torch.randn(4, 8, dtype=torch.float32)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float32"
        assert entry["target_dtype"] == "float32"

    def test_worker_handles_missing_target_dtype(self):
        """Backward compat: older serve code may not send target_dtype."""
        entry = {
            "name": "w",
            "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
            "dtype": "float32",
            "shape": [4, 8],
            # no target_dtype key
        }
        name, decoded = decode_from_ipc(entry)
        assert decoded.dtype == torch.float32
        assert decoded.shape == (4, 8)


# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------
class TestFullPipelineRoundTrip:
    """End-to-end: trainer → serve → worker."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_three_stage_round_trip(self, dtype):
        """Tensor survives trainer→serve→worker with correct dtype and values."""
        original = torch.randn(16, 32, dtype=dtype)
        # Stage 1: trainer encodes for HTTP
        http_entry = encode_for_http("model.layers.0.weight", original)
        # Stage 2: serve decodes HTTP, re-encodes for IPC
        name, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc(name, at_serve)
        # Stage 3: worker decodes IPC
        _, at_worker = decode_from_ipc(ipc_entry)
        assert at_worker.dtype == dtype
        assert at_worker.shape == original.shape
        if dtype == torch.bfloat16:
            # Two bf16→fp16→bf16 hops compound precision loss slightly
            torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
        else:
            torch.testing.assert_close(at_worker, original, atol=0, rtol=0)

    def test_bfloat16_precision_loss_is_bounded(self):
        """bf16→fp16→bf16 round-trip error should be small."""
        original = torch.randn(256, 256, dtype=torch.bfloat16)
        http_entry = encode_for_http("w", original)
        _, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc("w", at_serve)
        _, at_worker = decode_from_ipc(ipc_entry)
        max_err = (at_worker.float() - original.float()).abs().max().item()
        # bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
        assert max_err < 0.05, f"Max error {max_err} exceeds bound"

    def test_bfloat16_numpy_would_crash_without_fix(self):
        """Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
        t = torch.randn(4, 4, dtype=torch.bfloat16)
        with pytest.raises((RuntimeError, TypeError)):
            t.numpy()
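The wire format these tests describe (fp16 bytes, base64-encoded, with the original dtype carried alongside) can be illustrated with the standard library alone. This is a rough sketch under stated assumptions, not the `axolotl.utils.weight_serde` implementation: `encode_entry` and `decode_entry` are hypothetical helpers, plain Python floats stand in for tensors, and the dict keys are modelled on the fields the tests read (`data`, `dtype`, `shape`).

```python
import base64
import struct


def encode_entry(name, values, shape, dtype="torch.bfloat16"):
    """Pack floats as little-endian fp16, then base64 for a JSON-safe payload."""
    raw = struct.pack(f"<{len(values)}e", *values)
    return {
        "name": name,
        "data": base64.b64encode(raw).decode("ascii"),
        "dtype": dtype,  # original dtype, so the receiver can cast back
        "shape": list(shape),
    }


def decode_entry(entry):
    """Inverse of encode_entry: returns (name, values, shape, dtype)."""
    raw = base64.b64decode(entry["data"])
    count = len(raw) // 2  # two bytes per fp16 element
    values = list(struct.unpack(f"<{count}e", raw))
    return entry["name"], values, tuple(entry["shape"]), entry["dtype"]


# Values chosen to be exactly representable in fp16, so the round trip is lossless.
entry = encode_entry("w", [0.5, -1.25, 3.0, 0.0], shape=(2, 2))
name, values, shape, dtype = decode_entry(entry)
```

Keeping the original dtype as metadata is what lets a receiver restore bf16 after the fp16 hop, which is the behaviour `test_bfloat16_wire_format_is_fp16` checks on the real helpers.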