EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
This commit is contained in:
Wing Lian
2026-03-24 18:43:46 -04:00
committed by GitHub
parent e9883c91d4
commit c50c4acbf4
48 changed files with 5885 additions and 168 deletions

View File

@@ -22,6 +22,7 @@ RUN apt update && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf

214
examples/ebft/README.md Normal file
View File

@@ -0,0 +1,214 @@
# Energy-Based Fine-Tuning (EBFT)
EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl.
## Overview
EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
**Key advantages over SFT:**
- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously
**Key advantages over RLVR:**
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)
## Two Modes
EBFT supports two modes depending on your data format:
### Structured Mode (`mode: structured`, default)
For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).
- Extends GRPOTrainer — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions
### Strided Mode (`mode: strided`)
For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).
- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)
### Common to both modes:
- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)
## Quick Start
### Structured Mode (QA data + vLLM)
```bash
# 1. Start vLLM server
python -m trl.scripts.vllm_serve \
--model meta-llama/Llama-3.2-1B \
--host 0.0.0.0 --port 8000 \
--gpu-memory-utilization 0.3
# 2. Train
axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
```
### Strided Mode (unstructured text)
```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```
## Configuration
### Common EBFT Settings
```yaml
rl: ebft
ebft:
# Feature network: which layers to extract hidden states from
# Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
feature_layers: [0.25, 0.5, 0.75]
# How to pool per-token hidden states into sequence embeddings
# Options: "last_token" (recommended), "mean_pooling", "concat"
embed_method: last_token
# SVD whitening — strongly recommended (paper shows largest degradation without it)
use_whitening: true
# Reward = alignment_coef * alignment - diversity_coef * diversity
# Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
# diversity uses raw dot product — both are bounded after whitening.
alignment_coef: 1.0
diversity_coef: 1.0
# Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
# 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
ce_coef: 0.0
```
### Strided Mode Settings
```yaml
ebft:
mode: strided
stride: 8 # tokens between anchor points (paper default: 8)
context_length: 8 # context window per block (paper default: 8)
generate_max_len: 8 # tokens generated per block (paper default: 8)
n_samples_per_prompt: 4 # independent rollouts per document (>= 2 for RLOO)
temperature: 0.6
rl_coef: 1.0 # RL loss weight
advantage_estimator: rloo # rloo (recommended), group_norm, or reinforce
```
### Structured Mode Settings (via TRL)
```yaml
trl:
num_generations: 4 # samples per prompt
max_completion_length: 256 # max tokens to generate
temperature: 1.0
use_vllm: true
scale_rewards: true
loss_type: grpo
epsilon: 0.2
```
### Dataset Format
**Structured mode** — QA data with prompt + ground-truth completion:
```yaml
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
```
Transform returns: `{"prompt": ..., "ground_truth": ...}`
**Strided mode** — raw text tokenized to fixed length:
```yaml
datasets:
- path: sjelassi/swallow_code_20m
type: ebft_pretrain.transform
```
Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`
## How It Works
### Structured Mode
1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
4. **RLOO advantages**: subtract leave-one-out group mean
5. **Policy gradient**: clipped PPO-style loss
### Strided Mode
1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document
2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
4. **Per-block rewards**:
- **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
- **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
- **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
### Tracked Metrics
| Metric | Description |
|--------|-------------|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | alignment - diversity (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
## Tips and Recommendations
### Reward coefficients
- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching.
### Feature extraction
- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).
### Performance
- **`torch_compile: true`**: Recommended for strided mode. Provides additional speedup via graph compilation.
- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.
### Memory
- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.
## Example Configs
| Config | Mode | Model | Description |
|--------|------|-------|-------------|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |
## Citation
```bibtex
@article{jelassi2026matching,
title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
journal={arXiv preprint arXiv:2603.12248},
year={2026}
}
```

View File

@@ -0,0 +1,28 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT.
Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer.
"""
def transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
return {
"prompt": [
{"role": "user", "content": example["input"]},
],
"ground_truth": example["output"],
}
return transform_fn, {
"remove_columns": [
"id",
"domain",
"generation_algorithm",
"llm_judgement",
"unit_tests",
"tests_execution_status",
"average_test_score",
]
}

View File

@@ -0,0 +1,31 @@
"""
Dataset transform for unstructured text data with strided EBFT.
Tokenizes raw text into fixed-length input_ids for the strided trainer.
Sequences are padded to sequence_len for uniform batching.
"""
def transform(cfg, *args, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
text = example.get("question", example.get("text", ""))
if tokenizer is None:
return {"prompt": text}
encoded = tokenizer(
text,
truncation=True,
max_length=seq_len,
padding="max_length",
add_special_tokens=True,
return_tensors=None,
)
return {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"labels": list(encoded["input_ids"]),
}
return transform_fn, {"remove_columns": ["question", "answer"]}

View File

@@ -0,0 +1,80 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.
Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""
def transform(cfg, *args, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
# Extract prompt and completion from the example
prompt_text = example.get(
"input", example.get("prompt", example.get("question", ""))
)
completion_text = example.get(
"output", example.get("completion", example.get("answer", ""))
)
if tokenizer is None:
return {"prompt": prompt_text}
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
# Tokenize prompt and completion separately
prompt_enc = tokenizer(
prompt_text,
truncation=False,
add_special_tokens=True,
return_tensors=None,
)
completion_enc = tokenizer(
completion_text,
truncation=False,
add_special_tokens=False,
return_tensors=None,
)
prompt_ids = prompt_enc["input_ids"]
completion_ids = completion_enc["input_ids"]
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
total_len = len(prompt_ids) + len(completion_ids)
if total_len > seq_len:
# Truncate completion first, then prompt if needed
max_completion = seq_len - len(prompt_ids)
if max_completion < 1:
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
prompt_ids = prompt_ids[: seq_len - 1]
completion_ids = completion_ids[:1]
else:
completion_ids = completion_ids[:max_completion]
input_ids = prompt_ids + completion_ids
prompt_length = len(prompt_ids)
# Labels: -100 for prompt tokens, input_ids for completion tokens
labels = [-100] * prompt_length + completion_ids
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
# Signal to remove all original columns (filtered to existing ones at map time)
return transform_fn, {
"remove_columns": "__all__",
}

View File

@@ -0,0 +1,64 @@
# EBFT validation config — no vLLM, uses HF generate for simplicity
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml
base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: false
alignment_coef: 1.0
diversity_coef: 1.0
ce_coef: 0.0
trl:
num_generations: 4
max_completion_length: 128
temperature: 1.0
use_vllm: false
scale_rewards: true
loss_type: grpo
epsilon: 0.2
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:1%]
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 10
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-validation
wandb_project: ebft
wandb_run_id:
wandb_watch:
wandb_log_model:
logging_steps: 1
save_steps: 100

View File

@@ -0,0 +1,81 @@
# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
#
# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
# https://arxiv.org/abs/2603.12248
#
# Prerequisites:
# 1. Start vLLM server on a separate GPU:
# CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
# --model meta-llama/Llama-3.2-1B \
# --host 0.0.0.0 --port 8000 \
# --gpu-memory-utilization 0.4 --dtype bfloat16
#
# 2. Run training:
# CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
# --- Training method ---
rl: ebft
# --- EBFT configuration ---
ebft:
feature_layers: [0.25, 0.5, 0.75] # extract hidden states at 25%, 50%, 75% depth
embed_method: last_token # pool to sequence embedding via last token
use_whitening: false # SVD whitening (disable for speed in small runs)
alignment_coef: 1.0 # cosine similarity with ground-truth features
diversity_coef: 1.0 # pairwise similarity penalty
ce_coef: 0.0 # cross-entropy on ground-truth (0 = pure feature matching)
# --- Generation settings (via TRL/GRPO infrastructure) ---
trl:
num_generations: 4 # samples per prompt for RLOO
max_completion_length: 256 # max generated tokens
temperature: 1.0
use_vllm: true
scale_rewards: true
loss_type: grpo
epsilon: 0.2
# --- Dataset ---
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:1%] # first 1% for validation runs
# --- Training hyperparameters ---
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 50
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
weight_decay: 0.01
# --- LoRA (recommended to reduce memory with frozen feature network) ---
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
# --- Hardware ---
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama-1b-opencode
# --- Logging ---
use_tensorboard: true
logging_steps: 1
save_steps: 25

View File

@@ -0,0 +1,65 @@
# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided # strided block-parallel generation
stride: 8 # tokens between anchor points
context_length: 8 # context window per block
generate_max_len: 8 # tokens to generate per block
n_samples_per_prompt: 4 # rollouts per document
temperature: 0.6
top_p: 1.0
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.03 # small CE weight for structured data
advantage_estimator: rloo
min_completion_prefix: 8 # skip anchors too close to prompt boundary
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
# (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # required for flex_attention (non-reentrant causes CheckpointError)
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured
wandb_project: ebft
logging_steps: 1
save_steps: 100

View File

@@ -0,0 +1,60 @@
# EBFT Strided Mode: For unstructured text data (raw code, prose)
# Uses strided block-parallel generation — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided # strided block-parallel generation
stride: 8 # tokens between anchor points
context_length: 8 # context window per block
generate_max_len: 8 # tokens to generate per block
n_samples_per_prompt: 4 # rollouts per document
temperature: 0.6
top_p: 1.0
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.0
advantage_estimator: rloo
datasets:
- path: sjelassi/swallow_code_20m
type: ebft_pretrain.transform
split: train[:100]
sequence_len: 256
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 5
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-strided-validation
wandb_project: ebft
logging_steps: 1
save_steps: 100

View File

@@ -0,0 +1,69 @@
# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml
base_model: meta-llama/Llama-3.2-3B
rl: ebft
ebft:
mode: strided
stride: 8
context_length: 8
generate_max_len: 8
n_samples_per_prompt: 4
temperature: 0.6
top_p: 1.0
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.0 # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate
advantage_estimator: rloo
datasets:
- path: sjelassi/swallow_code_20m
type: ebft_pretrain.transform
split: train[:5000]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100
learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01
adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:
use_reentrant: true
ddp: false
device_map:
"": 0
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided
wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50

View File

@@ -0,0 +1,58 @@
# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# Feature network is CPU-offloaded to fit in single 32GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml
base_model: meta-llama/Llama-3.1-8B
rl: ebft
ebft:
mode: strided
stride: 8
context_length: 8
generate_max_len: 8
n_samples_per_prompt: 4
temperature: 0.6
top_p: 1.0
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.0
advantage_estimator: rloo
datasets:
- path: sjelassi/swallow_code_20m
type: ebft_pretrain.transform
split: train[:5000]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100
learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01
bf16: auto
flash_attention: false # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
special_tokens:
pad_token: "<|end_of_text|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided
wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50

View File

@@ -0,0 +1,86 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml
base_model: Qwen/Qwen3.5-4B
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: false
alignment_coef: 1.0
diversity_coef: 1.0
ce_coef: 0.0
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
scale_rewards: true
loss_type: grpo
epsilon: 0.2
generation_kwargs:
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
chat_template_kwargs:
enable_thinking: false
async_prefetch: true
vllm_server_timeout: 300
vllm:
gpu_memory_utilization: 0.5
max_model_len: 2048
serve_module: axolotl.scripts.vllm_serve_lora
enforce_eager: true
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured-async
wandb_project: ebft
logging_steps: 1
save_steps: 50

View File

@@ -0,0 +1,77 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
# --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml
base_model: Qwen/Qwen3.5-4B
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: false
alignment_coef: 1.0
diversity_coef: 1.0
ce_coef: 0.0
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
scale_rewards: true
loss_type: grpo
epsilon: 0.2
generation_kwargs:
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
chat_template_kwargs:
enable_thinking: false # disable Qwen3.5 thinking mode for shorter completions
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured
wandb_project: ebft
logging_steps: 1
save_steps: 50

View File

@@ -0,0 +1,82 @@
# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
#
# 2. Run training on GPU 1:
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
# axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml
base_model: Qwen/Qwen3.5-9B
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: false
alignment_coef: 1.0
diversity_coef: 1.0
ce_coef: 0.0
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
scale_rewards: true
loss_type: grpo
epsilon: 0.2
generation_kwargs:
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
chat_template_kwargs:
enable_thinking: false
vllm_server_timeout: 300
vllm:
gpu_memory_utilization: 0.7
max_model_len: 2048
serve_module: axolotl.scripts.vllm_serve_lora
enforce_eager: true
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10
learning_rate: 3.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|endoftext|>"
val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-9b-structured
wandb_project: ebft
logging_steps: 1
save_steps: 50

View File

@@ -38,18 +38,14 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
# Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
# We default to axolotl's serve module instead of TRL's because TRL's sends
# truncate_prompt_tokens which is unsupported in vLLM 0.17+.
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
serve_module = "axolotl.scripts.vllm_serve_lora"
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -79,6 +75,12 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
cli_enforce_eager = cli_args.get("enforce_eager")
cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
raw_enforce_eager = (
cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
)
enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
base_kwargs = dict(
model=model,
tensor_parallel_size=tensor_parallel_size,
@@ -89,6 +91,7 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
enforce_eager=enforce_eager,
)
# Use LoRAScriptArguments when serving with native LoRA support
@@ -98,6 +101,10 @@ def do_vllm_serve(
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
# Disable native LoRA in vLLM if not using vllm_lora_sync
# (merged weight sync via batch_update doesn't need vLLM LoRA mode)
if not getattr(cfg.trl, "vllm_lora_sync", False):
lora_kwargs["enable_lora"] = False
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(

View File

@@ -118,7 +118,7 @@ def load_preference_datasets(
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

View File

@@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = AxolotlKTOTrainer
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
elif self.cfg.rl is RLType.EBFT:
from axolotl.core.trainers.ebft import EBFTStrategy
trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
elif self.cfg.rl is RLType.EBFT:
from axolotl.core.trainers.ebft import EBFTStrategy
training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:

View File

@@ -4,6 +4,8 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .ebft.strided import AxolotlStridedEBFTTrainer
from .ebft.trainer import AxolotlEBFTTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,

View File

@@ -0,0 +1,213 @@
"""EBFT (Energy-Based Fine-Tuning) Strategy for training
Two modes:
- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
"""
from typing import Any
from axolotl.core.trainers.ebft.args import (
AxolotlAsyncEBFTConfig,
AxolotlEBFTConfig,
AxolotlStridedEBFTConfig,
)
from axolotl.utils.dict import DictDefault
def _get_ebft_mode(cfg: DictDefault) -> str:
"""Determine EBFT mode from config."""
if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
return cfg.ebft.mode
return "structured"
class EBFTStrategy:
"""Strategy for EBFT training — dispatches between structured and strided modes."""
@classmethod
def get_trainer_class(cls, cfg: DictDefault | None = None):
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer
return AxolotlStridedEBFTTrainer
# Structured mode: async or sync
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
if use_async:
from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
return AxolotlAsyncEBFTTrainer
from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer
return AxolotlEBFTTrainer
@classmethod
def get_training_args_class(cls, cfg: DictDefault | None = None):
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
return AxolotlStridedEBFTConfig
# Structured mode: async or sync config
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
if use_async:
return AxolotlAsyncEBFTConfig
return AxolotlEBFTConfig
@classmethod
def is_strided(cls, cfg: DictDefault) -> bool:
return _get_ebft_mode(cfg) == "strided"
@classmethod
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
"""Map axolotl YAML config fields to training args kwargs."""
kwargs: dict[str, Any] = {}
mode = _get_ebft_mode(cfg)
# Common EBFT fields
ebft = cfg.ebft
if ebft:
if ebft.feature_layers is not None:
kwargs["ebft_feature_layers"] = ebft.feature_layers
if ebft.embed_method is not None:
kwargs["ebft_embed_method"] = ebft.embed_method
if ebft.use_whitening is not None:
kwargs["ebft_use_whitening"] = ebft.use_whitening
if ebft.alignment_coef is not None:
kwargs["ebft_alignment_coef"] = ebft.alignment_coef
if ebft.diversity_coef is not None:
kwargs["ebft_diversity_coef"] = ebft.diversity_coef
if ebft.ce_coef is not None:
kwargs["ebft_ce_coef"] = ebft.ce_coef
if getattr(ebft, "adaptive_max_tokens", None) is not None:
kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
if getattr(ebft, "gt_length_multiplier", None) is not None:
kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier
if mode == "strided":
# Strided-specific fields
if ebft:
if ebft.stride is not None:
kwargs["ebft_stride"] = ebft.stride
if ebft.context_length is not None:
kwargs["ebft_context_length"] = ebft.context_length
if ebft.generate_max_len is not None:
kwargs["ebft_generate_max_len"] = ebft.generate_max_len
if ebft.n_samples_per_prompt is not None:
kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
if ebft.temperature is not None:
kwargs["ebft_temperature"] = ebft.temperature
if ebft.top_p is not None:
kwargs["ebft_top_p"] = ebft.top_p
if ebft.rl_coef is not None:
kwargs["ebft_rl_coef"] = ebft.rl_coef
if ebft.advantage_estimator is not None:
kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
if ebft.min_completion_prefix is not None:
kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
else:
# Structured mode: map TRL config fields
trl = cfg.trl
if trl:
if trl.use_vllm:
kwargs["use_vllm"] = trl.use_vllm
if trl.vllm_mode:
kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
vllm_cfg = cfg.vllm
if vllm_cfg:
kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)
kwargs["vllm_tensor_parallel_size"] = (
vllm_cfg.tensor_parallel_size
)
kwargs["vllm_server_host"] = trl.vllm_server_host or (
trl.vllm.host if trl.vllm else None
)
kwargs["vllm_server_port"] = trl.vllm_server_port or (
trl.vllm.port if trl.vllm else None
)
if trl.vllm_server_timeout:
kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.num_generations:
kwargs["num_generations"] = trl.num_generations
if trl.max_completion_length is not None:
kwargs["max_completion_length"] = trl.max_completion_length
if trl.temperature is not None:
kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
kwargs["min_p"] = trl.min_p
if trl.num_iterations is not None:
kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
kwargs["epsilon_high"] = trl.epsilon_high
if trl.scale_rewards is not None:
kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.log_completions is not None:
kwargs["log_completions"] = trl.log_completions
if trl.num_completions_to_print is not None:
kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.sync_ref_model:
kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.repetition_penalty is not None:
kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.generation_kwargs is not None:
kwargs["generation_kwargs"] = trl.generation_kwargs
if trl.chat_template_kwargs is not None:
kwargs["chat_template_kwargs"] = trl.chat_template_kwargs
# Async prefetch fields (only pass when enabled — sync config doesn't have these)
if getattr(trl, "async_prefetch", False):
kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "vllm_sync_interval", None) is not None:
kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "vllm_lora_sync", False):
kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return kwargs
@classmethod
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
return []
@classmethod
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
return {}
@classmethod
def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
mode = _get_ebft_mode(cfg) if cfg else "structured"
if mode == "strided":
return [
"dataset_num_proc",
"max_length",
"max_prompt_length",
"include_tokens_per_second",
"beta",
]
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod
def get_collator(cls, *args, **kwargs):
return None

View File

@@ -0,0 +1,133 @@
"""
EBFT-specific training arguments.
Two config classes:
- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
"""
from dataclasses import dataclass, field
from transformers import TrainingArguments
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
# -- Shared EBFT fields as a mixin --
@dataclass
class EBFTFieldsMixin:
"""Common fields shared between structured and strided EBFT configs."""
ebft_feature_layers: list[float] = field(
default_factory=lambda: [0.25, 0.5, 0.75],
metadata={"help": "Fractional layer depths for feature extraction"},
)
ebft_embed_method: str = field(
default="last_token",
metadata={"help": "Pooling method: 'last_token', 'mean_pooling', or 'concat'"},
)
ebft_use_whitening: bool = field(
default=False,
metadata={"help": "Apply SVD whitening to feature embeddings"},
)
ebft_alignment_coef: float = field(
default=1.0,
metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
)
ebft_diversity_coef: float = field(
default=1.0,
metadata={"help": "Coefficient for diversity penalty"},
)
ebft_ce_coef: float = field(
default=0.0,
metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
)
ebft_adaptive_max_tokens: bool = field(
default=True,
metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
)
ebft_gt_length_multiplier: float = field(
default=1.5,
metadata={
"help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
},
)
# -- Structured mode: extends GRPOTrainer for QA data with vLLM --
@dataclass
class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
"""EBFT config for structured QA data — extends GRPOConfig."""
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
},
)
# -- Async structured mode: extends FastAsyncGRPOConfig --
@dataclass
class AxolotlAsyncEBFTConfig(
EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
):
"""EBFT config for async structured QA data — extends FastAsyncGRPOConfig.
Includes all async fields: async_prefetch, vllm_lora_sync,
skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
"""
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
},
)
# -- Strided mode: extends TrainingArguments for unstructured text --
@dataclass
class AxolotlStridedEBFTConfig(
EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
):
"""EBFT config for unstructured text with strided block-parallel generation."""
ebft_stride: int = field(
default=8,
metadata={"help": "Stride between anchor points (in tokens)"},
)
ebft_context_length: int = field(
default=8,
metadata={"help": "Context window size for each block"},
)
ebft_generate_max_len: int = field(
default=8,
metadata={"help": "Number of tokens to generate per block"},
)
ebft_n_samples_per_prompt: int = field(
default=4,
metadata={"help": "Number of independent rollouts per document"},
)
ebft_temperature: float = field(
default=0.6,
metadata={"help": "Sampling temperature for strided generation"},
)
ebft_top_p: float = field(
default=1.0,
metadata={"help": "Top-p nucleus sampling threshold"},
)
ebft_rl_coef: float = field(
default=1.0,
metadata={"help": "RL policy gradient loss coefficient"},
)
ebft_advantage_estimator: str = field(
default="rloo",
metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
)
ebft_min_completion_prefix: int = field(
default=0,
metadata={"help": "Minimum tokens into completion before placing anchors"},
)

View File

@@ -0,0 +1,308 @@
"""
Fused Triton kernels for strided EBFT.
These kernels eliminate intermediate tensor materializations that dominate
the elementwise/fill category (~40% of CUDA time in profiling).
Kernels:
1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
2. fused_masked_reinforce_loss: -logp * advantage * mask, reduced to scalar
3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
"""
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# 1. Fused log_softmax + gather (selective log softmax)
# ---------------------------------------------------------------------------
# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
# This avoids materializing the full (B, S, V) log_softmax output.
@triton.jit
def _fused_log_softmax_gather_kernel(
logits_ptr, # (B*S, V) row-major
labels_ptr, # (B*S,) int64
output_ptr, # (B*S,) float32
V: tl.constexpr, # vocab size
BLOCK_V: tl.constexpr, # tile width over vocab
):
"""Compute log_softmax(logits)[label] for each row without materializing full output."""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + row * V
label = tl.load(labels_ptr + row)
# Pass 1: find max for numerical stability
max_val = -float("inf")
for v_start in range(0, V, BLOCK_V):
v_offsets = v_start + tl.arange(0, BLOCK_V)
mask = v_offsets < V
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
max_val = tl.maximum(max_val, tl.max(vals, axis=0))
# Pass 2: compute sum(exp(x - max))
sum_exp = 0.0
for v_start in range(0, V, BLOCK_V):
v_offsets = v_start + tl.arange(0, BLOCK_V)
mask = v_offsets < V
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)
log_sum_exp = tl.log(sum_exp)
# Gather: log_softmax[label] = logits[label] - max - log_sum_exp
target_logit = tl.load(logits_row_ptr + label)
result = target_logit - max_val - log_sum_exp
tl.store(output_ptr + row, result)
def fused_log_softmax_gather(
logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
"""Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
Args:
logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
labels: (B, S) or (B*S,) int64 tensor of target indices
Returns:
(B, S) or (B*S,) float32 tensor of selected log probabilities
"""
orig_shape = logits.shape[:-1]
V = logits.shape[-1]
logits_2d = logits.reshape(-1, V).contiguous()
labels_1d = labels.reshape(-1).contiguous()
n_rows = logits_2d.shape[0]
output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
# Choose BLOCK_V: must be power of 2, large enough for good occupancy
BLOCK_V = min(triton.next_power_of_2(V), 65536)
_fused_log_softmax_gather_kernel[(n_rows,)](
logits_2d,
labels_1d,
output,
V=V,
BLOCK_V=BLOCK_V,
)
return output.view(orig_shape)
# ---------------------------------------------------------------------------
# 2. Fused masked REINFORCE loss reduction
# ---------------------------------------------------------------------------
# Instead of: (-logp * adv * mask).sum() / mask.sum()
# We do the masked multiply-accumulate in one kernel, returning (sum, count).
@triton.jit
def _fused_reinforce_loss_kernel(
logps_ptr, # (N,) float32 per-token log probs
advs_ptr, # (N,) float32 advantages
mask_ptr, # (N,) bool action mask
partial_sum_ptr, # (n_blocks,) partial sums
partial_cnt_ptr, # (n_blocks,) partial counts
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
block_id = tl.program_id(0)
offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
valid = offsets < N
logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)
# -logp * advantage * mask
loss = -logps * advs * m
block_sum = tl.sum(loss, axis=0)
block_cnt = tl.sum(m, axis=0)
tl.store(partial_sum_ptr + block_id, block_sum)
tl.store(partial_cnt_ptr + block_id, block_cnt)
def fused_reinforce_loss(
per_token_logps: torch.Tensor,
advantages: torch.Tensor,
action_mask: torch.Tensor,
) -> torch.Tensor:
"""Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
All inputs should be flat or will be flattened. Returns scalar loss tensor.
"""
logps_flat = per_token_logps.reshape(-1).contiguous()
advs_flat = advantages.reshape(-1).contiguous()
mask_flat = action_mask.reshape(-1).contiguous()
N = logps_flat.shape[0]
BLOCK_N = 1024
n_blocks = triton.cdiv(N, BLOCK_N)
partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
_fused_reinforce_loss_kernel[(n_blocks,)](
logps_flat,
advs_flat,
mask_flat,
partial_sum,
partial_cnt,
N=N,
BLOCK_N=BLOCK_N,
)
total_sum = partial_sum.sum()
total_cnt = partial_cnt.sum().clamp(min=1)
return total_sum / total_cnt
# ---------------------------------------------------------------------------
# 3. Fused cosine similarity (batched, for EBFT rewards)
# ---------------------------------------------------------------------------
# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
# we fuse the dot product, norm computation, and division into one kernel.
@triton.jit
def _fused_cosine_sim_kernel(
a_ptr, # (N, D) first set of vectors
b_ptr, # (N, D) second set of vectors
out_ptr, # (N,) cosine similarities
D: tl.constexpr,
BLOCK_D: tl.constexpr,
):
row = tl.program_id(0)
a_row_ptr = a_ptr + row * D
b_row_ptr = b_ptr + row * D
dot = 0.0
norm_a = 0.0
norm_b = 0.0
for d_start in range(0, D, BLOCK_D):
d_offsets = d_start + tl.arange(0, BLOCK_D)
mask = d_offsets < D
a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
dot += tl.sum(a_vals * b_vals, axis=0)
norm_a += tl.sum(a_vals * a_vals, axis=0)
norm_b += tl.sum(b_vals * b_vals, axis=0)
denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
denom = tl.where(denom > 1e-8, denom, 1e-8)
result = dot / denom
tl.store(out_ptr + row, result)
def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Compute cosine similarity along the last dimension.
Args:
a, b: (..., D) tensors of the same shape
Returns:
(...,) tensor of cosine similarities
"""
orig_shape = a.shape[:-1]
D = a.shape[-1]
a_2d = a.reshape(-1, D).contiguous()
b_2d = b.reshape(-1, D).contiguous()
N = a_2d.shape[0]
output = torch.empty(N, device=a.device, dtype=torch.float32)
BLOCK_D = min(triton.next_power_of_2(D), 4096)
_fused_cosine_sim_kernel[(N,)](
a_2d,
b_2d,
output,
D=D,
BLOCK_D=BLOCK_D,
)
return output.view(orig_shape)
# ---------------------------------------------------------------------------
# 4. Fused pairwise diversity penalty
# ---------------------------------------------------------------------------
# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1)
# We compute the pairwise dot products and exclusion in one kernel.
@triton.jit
def _fused_diversity_kernel(
emb_ptr, # (B, N, D) embeddings, row-major
out_ptr, # (B, N) diversity penalties
N: tl.constexpr, # n_samples
D: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""For each (b, i), compute mean dot product to all j != i."""
b = tl.program_id(0)
i = tl.program_id(1)
# Pointer to emb[b, i, :]
emb_bi_ptr = emb_ptr + (b * N + i) * D
total_sim = 0.0
for j in range(N):
emb_bj_ptr = emb_ptr + (b * N + j) * D
dot = 0.0
for d_start in range(0, D, BLOCK_D):
d_offsets = d_start + tl.arange(0, BLOCK_D)
d_mask = d_offsets < D
a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
tl.float32
)
b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
tl.float32
)
dot += tl.sum(a_vals * b_vals, axis=0)
# Exclude self-similarity (j == i)
is_other = j != i
total_sim += dot * is_other
result = total_sim / (N - 1)
tl.store(out_ptr + b * N + i, result)
def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
"""Compute mean pairwise dot product (excluding self) per sample.
Args:
embeddings: (B, N, D) tensor where N is n_samples
Returns:
(B, N) tensor of diversity penalties
"""
B, N, D = embeddings.shape
embeddings = embeddings.contiguous()
output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
if N <= 1:
return output # diversity is 0 when there's only one sample
BLOCK_D = min(triton.next_power_of_2(D), 4096)
_fused_diversity_kernel[(B, N)](
embeddings,
output,
N=N,
D=D,
BLOCK_D=BLOCK_D,
)
return output

View File

@@ -0,0 +1,264 @@
"""
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).
Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@torch.no_grad()
def extract_hidden_states(
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
layer_indices: list[int],
batch_size: int | None = None,
) -> torch.Tensor:
"""
Forward pass through model, extracting and concatenating hidden states
at specified layer indices.
Args:
model: The frozen feature network
input_ids: (B, S) token ids
attention_mask: (B, S) attention mask
layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model)
batch_size: If set, process in chunks to reduce peak memory
Returns:
Concatenated hidden states: (B, S, num_layers * H)
"""
if batch_size is None:
batch_size = input_ids.shape[0]
# Use the inner transformer body (skips lm_head) when available.
# This avoids the expensive hidden_dim × vocab_size matmul whose
# output (logits) is never used — only hidden_states are needed.
body = getattr(model, "model", None)
if body is not None and hasattr(body, "forward"):
forward_model = body
else:
forward_model = model
all_features = []
for i in range(0, input_ids.shape[0], batch_size):
chunk_ids = input_ids[i : i + batch_size]
chunk_mask = attention_mask[i : i + batch_size]
outputs = forward_model(
chunk_ids,
attention_mask=chunk_mask,
output_hidden_states=True,
return_dict=True,
)
# hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H)
# index 0 is the embedding layer output
hidden_states = outputs.hidden_states
chunk_features = []
for idx in layer_indices:
chunk_features.append(hidden_states[idx])
# Concatenate across feature dimension: (B, S, num_layers * H)
all_features.append(torch.cat(chunk_features, dim=-1))
return torch.cat(all_features, dim=0)
def apply_embed_method(
hidden_states: torch.Tensor,
method: str,
attention_mask: torch.Tensor | None = None,
prompt_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Pool per-token hidden states into per-sequence embeddings.
Args:
hidden_states: (B, S, D) concatenated hidden states
method: One of "last_token", "mean_pooling", "completion_mean", "concat"
attention_mask: (B, S) mask for mean pooling
prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)
Returns:
Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
(B, 3*D) for concat
"""
if method == "last_token":
if attention_mask is not None:
# Find last non-padding position per sample
last_idx = attention_mask.sum(dim=1).long() - 1 # (B,)
return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
return hidden_states[:, -1, :]
if method == "mean_pooling":
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
return hidden_states.mean(dim=1)
if method == "completion_mean":
# Mean pool over completion tokens only (exclude prompt)
if prompt_lengths is None:
raise ValueError("completion_mean requires prompt_lengths")
B, S, _ = hidden_states.shape
positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S)
comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S)
if attention_mask is not None:
comp_mask = comp_mask & attention_mask.bool()
mask = comp_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
if method == "concat":
B, S, D = hidden_states.shape
if attention_mask is not None:
valid_lens = attention_mask.sum(dim=1).long() # (B,)
else:
valid_lens = torch.full(
(B,), S, device=hidden_states.device, dtype=torch.long
)
# Compute quartile positions relative to valid length per sample
# First valid position index for each sample (handles right-padding)
q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
batch_idx = torch.arange(B, device=hidden_states.device)
return torch.cat(
[
hidden_states[batch_idx, q1],
hidden_states[batch_idx, q2],
hidden_states[batch_idx, q3],
],
dim=-1,
)
raise ValueError(f"Unknown embed_method: {method}")
@torch.no_grad()
def get_alignment_rewards(
gen_embedding: torch.Tensor,
gt_embedding: torch.Tensor,
) -> torch.Tensor:
"""
Compute alignment reward as cosine similarity between generated
and ground-truth feature embeddings.
Args:
gen_embedding: (B, D) generated sequence embeddings
gt_embedding: (B, D) ground-truth sequence embeddings
If num_generations > 1, gt_embedding should be repeated
to match gen_embedding's batch dim.
Returns:
Alignment rewards: (B,) cosine similarities in [-1, 1]
"""
return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)
@torch.no_grad()
def get_diversity_rewards(
gen_embedding: torch.Tensor,
num_generations: int,
) -> torch.Tensor:
"""
Compute diversity penalty as mean pairwise dot-product similarity
between samples from the same prompt (excluding self-similarity).
Args:
gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
num_generations: Number of generations per prompt
Returns:
Diversity penalties: (B,) mean similarity to other samples from same prompt
"""
if num_generations <= 1:
return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)
num_prompts = gen_embedding.shape[0] // num_generations
# Reshape to (num_prompts, num_generations, D)
reshaped = gen_embedding.view(num_prompts, num_generations, -1)
# Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
sims = torch.bmm(reshaped, reshaped.transpose(1, 2))
# Zero out self-similarity (diagonal)
eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
# Mean similarity to other samples: (num_prompts, num_generations)
diversity = sims.sum(dim=-1) / (num_generations - 1)
# Flatten back to (B,)
return diversity.view(-1)
def whiten_embeddings_batched(
phi: torch.Tensor,
phi_gt: torch.Tensor,
whiten_tol: float = 1e-5,
normalize: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Whiten generated embeddings using SVD, then apply same transform to ground-truth.
Whitening decorrelates feature dimensions so no single direction dominates
the feature-matching loss. Uses pseudo-inverse for rank-deficient cases.
Note: Singular values scale with sqrt(B), so reward magnitudes are
batch-size dependent. This is acceptable because B = n_samples_per_prompt
which is fixed during training (typically 2-4).
Args:
phi: (B, D) generated embeddings (used to estimate covariance)
phi_gt: (B, D) ground-truth embeddings
whiten_tol: Tolerance for singular value cutoff
normalize: If True, L2-normalize after whitening
Returns:
Whitened (phi, phi_gt) tuple, each (B, D)
"""
phi_f = phi.float()
phi_gt_f = phi_gt.float()
# Feature-space SVD: operate on phi_f.T (D, B) so U is (D, D)
try:
U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
except torch._C._LinAlgError:
# Fallback: add small noise
noise = 1e-6 * phi_f.abs().mean()
try:
U, S, _ = torch.linalg.svd(
(phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
full_matrices=False,
)
except torch._C._LinAlgError:
if normalize:
return (
F.normalize(phi, p=2, dim=-1),
F.normalize(phi_gt, p=2, dim=-1),
)
return phi, phi_gt
U, S = U.squeeze(0), S.squeeze(0) # U: (D, min(D,B)), S: (min(D,B),)
# Safe inverse of singular values
s_max = S.max()
inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))
# W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
W = (U * inv_s.unsqueeze(0)) @ U.T # (D, D)
phi_w = (phi_f @ W).to(phi.dtype) # (B, D)
phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype) # (B, D)
if normalize:
phi_w = F.normalize(phi_w, p=2, dim=-1)
phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)
return phi_w, phi_gt_w

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,531 @@
"""
EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.
Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
the standard GRPO reward function interface.
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""
import contextlib
import copy
from typing import TYPE_CHECKING, Any
import torch
from datasets import Dataset, IterableDataset
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
from axolotl.core.trainers.ebft.rewards import (
apply_embed_method,
extract_hidden_states,
get_alignment_rewards,
get_diversity_rewards,
whiten_embeddings_batched,
)
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOTrainer,
)
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from collections import defaultdict
from accelerate import Accelerator
from trl.generation.vllm_generation import VLLMGeneration
LOG = get_logger(__name__)
class EBFTMixin:
"""
Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.
Provides:
- Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
- _feature_matching_reward() callable for GRPO reward function interface
- _sequential_rollout() for multi-turn conversations
"""
# Type stubs for attributes provided by the composed GRPOTrainer base class.
# These are not defined here but accessed via cooperative multiple inheritance.
if TYPE_CHECKING:
accelerator: Accelerator
model: PreTrainedModel
args: AxolotlEBFTConfig
processing_class: PreTrainedTokenizerBase
num_generations: int
vllm_generation: VLLMGeneration
_metrics: defaultdict
_tag_names = ["trl", "ebft", "axolotl"]
def __init__(
self,
model: str | PreTrainedModel,
args: AxolotlEBFTConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset
| IterableDataset
| dict[str, Dataset | IterableDataset]
| None = None,
processing_class: PreTrainedTokenizerBase | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: Any | None = None,
):
# Pass our feature-matching reward function to GRPOTrainer
# It will be called with (prompts, completions, **kwargs) where
# kwargs includes all extra dataset fields like "ground_truth"
super().__init__( # type: ignore[call-arg]
model=model,
reward_funcs=[self._feature_matching_reward],
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
assert args is not None
# --- Feature network setup ---
unwrapped = self.accelerator.unwrap_model(self.model)
# Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping
self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
unwrapped, "disable_adapter"
)
if self._share_feature_weights:
# Share weights: use actor's base model with adapters disabled.
# Saves a full model copy (~8 GB for 4B model).
self.feature_network = None
param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
LOG.info(
f"EBFT feature network shares actor weights (PEFT disable_adapter). "
f"Saving ~{param_gb:.1f} GB"
)
else:
LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
self.feature_network = copy.deepcopy(unwrapped)
for param in self.feature_network.parameters():
param.requires_grad = False
self.feature_network.eval()
# Compute layer indices from fractional depths
# Handle VLM models where num_hidden_layers is on text_config
config = unwrapped.config
if hasattr(config, "text_config") and hasattr(
config.text_config, "num_hidden_layers"
):
config = config.text_config
num_layers = config.num_hidden_layers
self.feature_layer_indices = [
int(frac * num_layers) for frac in args.ebft_feature_layers
]
LOG.info(
f"EBFT feature extraction from layers {self.feature_layer_indices} "
f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
)
if args.ebft_adaptive_max_tokens:
LOG.info(
f"EBFT adaptive max_tokens enabled "
f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
)
_adaptive_max_lock = None # initialized lazily
def _generate_only(self, inputs, rank0_only=False):
"""Override to set per-batch max_tokens based on ground-truth length.
Uses a lock to prevent race conditions in async mode where concurrent
BG threads could interleave mutations of max_completion_length.
"""
import threading
args = self.args
if (
args.ebft_adaptive_max_tokens
and hasattr(self, "vllm_generation")
and inputs
):
gt_texts = [
x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
]
if gt_texts:
gt_token_counts = [
len(self.processing_class.encode(gt, add_special_tokens=False))
for gt in gt_texts
]
multiplier = args.ebft_gt_length_multiplier
max_completion = self.vllm_generation.max_completion_length
adaptive_max = max(
min(int(c * multiplier), max_completion) for c in gt_token_counts
)
adaptive_max = max(adaptive_max, 64)
if self._adaptive_max_lock is None:
self._adaptive_max_lock = threading.Lock()
with self._adaptive_max_lock:
original = self.vllm_generation.max_completion_length
self.vllm_generation.max_completion_length = adaptive_max
try:
return super()._generate_only(inputs, rank0_only)
finally:
self.vllm_generation.max_completion_length = original
return super()._generate_only(inputs, rank0_only)
@torch.no_grad()
def _feature_matching_reward(
self,
prompts: list,
completions: list,
ground_truth: list[str] | None = None,
remaining_turns: list | None = None,
**kwargs,
) -> list[float]:
"""
Compute feature-matching rewards for generated completions.
This is called by GRPOTrainer's _generate_and_score_completions()
as a standard reward function. The `ground_truth` field comes from
the dataset via reward_kwargs.
For multi-turn conversations, `remaining_turns` contains the subsequent
user/assistant turn pairs. When present, we do sequential rollouts:
generate each assistant turn conditioned on history + previous generations,
then compute feature-matching rewards on the full generated conversation.
Args:
prompts: List of prompt strings/messages
completions: List of generated completion strings
ground_truth: List of reference completion strings (from dataset)
remaining_turns: List of remaining conversation turns after the
first assistant turn (for multi-turn rollouts)
Returns:
List of scalar rewards, one per completion
"""
if ground_truth is None:
LOG.warning("No ground_truth field in dataset — using zero rewards")
return [0.0] * len(prompts)
device = self.accelerator.device
args = self.args
num_gens = self.num_generations
# --- Multi-turn sequential rollout ---
# If remaining_turns is provided, generate subsequent assistant turns
# by calling vLLM for each turn, building up the full conversation.
if remaining_turns is not None and hasattr(self, "vllm_generation"):
completions = self._sequential_rollout(
prompts, completions, remaining_turns, num_gens
)
# --- Tokenize generated sequences: prompt + completion ---
gen_texts = []
gen_prompt_texts = []
for p, c in zip(prompts, completions, strict=True):
if isinstance(p, list):
prompt_text = self.processing_class.apply_chat_template(
p, tokenize=False, add_generation_prompt=True
)
else:
prompt_text = p
if isinstance(c, list):
comp_text = c[0].get("content", "") if c else ""
else:
comp_text = c
gen_texts.append(prompt_text + comp_text)
gen_prompt_texts.append(prompt_text)
gen_encoded = self.processing_class(
text=gen_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=getattr(self.args, "max_length", None)
or getattr(self.args, "max_seq_length", None)
or 2048,
add_special_tokens=False,
)
gen_ids = gen_encoded["input_ids"].to(device)
gen_mask = gen_encoded["attention_mask"].to(device)
# Compute prompt lengths for completion_mean pooling
gen_prompt_lengths = torch.tensor(
[
len(self.processing_class.encode(pt, add_special_tokens=False))
for pt in gen_prompt_texts
],
device=device,
)
# --- Tokenize ground-truth sequences: prompt + ground_truth ---
# For multi-turn (remaining_turns present), render the full GT conversation
# through the chat template to preserve role markers between turns.
gt_texts = []
gt_prompt_texts = []
for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
if i % num_gens != 0:
continue # Only need one GT per prompt group
if isinstance(p, list):
prompt_text = self.processing_class.apply_chat_template(
p, tokenize=False, add_generation_prompt=True
)
# Multi-turn: build full GT conversation with remaining turns
if remaining_turns is not None:
prompt_idx = i // num_gens
turns = (
remaining_turns[prompt_idx]
if prompt_idx < len(remaining_turns)
else []
)
if turns:
gt_conv = list(p) + [{"role": "assistant", "content": gt}]
gt_conv.extend(turns)
full_gt_text = self.processing_class.apply_chat_template(
gt_conv, tokenize=False, add_generation_prompt=False
)
gt_texts.append(full_gt_text)
gt_prompt_texts.append(prompt_text)
continue
else:
prompt_text = p
gt_texts.append(prompt_text + gt)
gt_prompt_texts.append(prompt_text)
gt_encoded = self.processing_class(
text=gt_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=getattr(self.args, "max_length", None)
or getattr(self.args, "max_seq_length", None)
or 2048,
add_special_tokens=False,
)
gt_ids = gt_encoded["input_ids"].to(device)
gt_mask = gt_encoded["attention_mask"].to(device)
gt_prompt_lengths = torch.tensor(
[
len(self.processing_class.encode(pt, add_special_tokens=False))
for pt in gt_prompt_texts
],
device=device,
)
# --- Extract features from frozen feature network ---
# INVARIANT: disable_adapter() yields the unmodified base weights because
# _sync_peft_weights_no_merge and _sync_lora_adapter never call
# merge_adapter() — they compute merged weights as new tensors or save
# the adapter to filesystem. Base weights are never modified in-place.
if self._share_feature_weights:
unwrapped = self.accelerator.unwrap_model(self.model)
feature_ctx = unwrapped.disable_adapter()
else:
unwrapped = self.feature_network
feature_ctx = contextlib.nullcontext()
with feature_ctx:
was_training = unwrapped.training
unwrapped.eval()
gen_hidden = extract_hidden_states(
unwrapped, gen_ids, gen_mask, self.feature_layer_indices
)
gt_hidden = extract_hidden_states(
unwrapped, gt_ids, gt_mask, self.feature_layer_indices
)
if was_training:
unwrapped.train()
# --- Pool to sequence-level embeddings ---
gen_emb = apply_embed_method(
gen_hidden,
args.ebft_embed_method,
gen_mask,
prompt_lengths=gen_prompt_lengths,
)
gt_emb = apply_embed_method(
gt_hidden,
args.ebft_embed_method,
gt_mask,
prompt_lengths=gt_prompt_lengths,
)
# --- Optional whitening ---
batch_size = gen_emb.shape[0]
if args.ebft_use_whitening and batch_size > 1:
num_prompts = batch_size // num_gens
gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
whitened_gen_list = []
whitened_gt_list = []
for i in range(num_prompts):
w_gen, w_gt = whiten_embeddings_batched(
gen_reshaped[i], gt_emb[i : i + 1]
)
whitened_gen_list.append(w_gen)
whitened_gt_list.append(w_gt)
gen_emb = torch.cat(whitened_gen_list, dim=0)
gt_emb = torch.cat(whitened_gt_list, dim=0)
else:
gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)
# Repeat gt_emb: each GT repeated num_generations times
gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)
# --- Compute rewards ---
alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
diversity = get_diversity_rewards(gen_emb, num_gens)
# Scale by 2 per paper equation (7):
# r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
alignment = alignment * 2
diversity = diversity * 2
rewards = (
args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
)
# Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
mean_gen = gen_reshaped.mean(dim=1) # (num_prompts, D)
cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()
# Log feature-matching metrics to console and wandb
_align = alignment.mean().item()
_divers = diversity.mean().item()
_reward = rewards.mean().item()
_cfm = cfm_loss.item()
LOG.info(
f"ebft reward | "
f"align {_align:+.3f} ^ | "
f"divers {_divers:+.3f} v | "
f"cfm {_cfm:.3f} v | "
f"reward {_reward:+.3f} ^"
)
# Log to wandb via trainer's _metrics (picked up by GRPO's logging)
mode = "train" if self.model.training else "eval"
if hasattr(self, "_metrics"):
self._metrics[mode]["ebft/alignment"].append(_align)
self._metrics[mode]["ebft/diversity"].append(_divers)
self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
self._metrics[mode]["ebft/reward"].append(_reward)
return rewards.cpu().tolist()
@torch.no_grad()
def _sequential_rollout(
self,
prompts: list,
first_completions: list,
remaining_turns: list,
num_gens: int,
) -> list:
"""
Extend single-turn completions into multi-turn conversations.
For each prompt group, takes the first generated assistant turn and
sequentially generates subsequent assistant turns by calling vLLM,
building up a full multi-turn conversation.
Args:
prompts: List of prompt message lists (repeated num_gens times)
first_completions: List of generated first-turn completions
remaining_turns: List of remaining turn pairs after first assistant turn.
Each element is a list of dicts: [{"role": "user", "content": "..."},
{"role": "assistant", "content": "...GT..."}]
num_gens: Number of generations per prompt
Returns:
Extended completions incorporating all generated turns
"""
vllm_client = self.vllm_generation.vllm_client
max_tokens = getattr(self.args, "max_completion_length", 256)
temperature = getattr(self.args, "temperature", 0.7)
gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}
extended_completions = []
for idx in range(len(prompts)):
prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
first_comp = first_completions[idx]
# Extract first completion text
if isinstance(first_comp, list):
first_text = first_comp[0].get("content", "") if first_comp else ""
else:
first_text = first_comp
# Get remaining turns for this prompt (same for all num_gens copies)
prompt_idx = idx // num_gens
turns = (
remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
)
if not turns:
extended_completions.append(first_text)
continue
# Build conversation with generated first turn
conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
# Generate subsequent turns
for turn in turns:
if turn["role"] == "user":
conv.append(turn)
elif turn["role"] == "assistant":
try:
result = vllm_client.chat(
messages=[conv],
n=1,
max_tokens=max_tokens,
temperature=temperature,
generation_kwargs=gen_kwargs,
)
gen_ids = result.get("completion_ids", [[]])[0]
gen_text = self.processing_class.decode(
gen_ids, skip_special_tokens=True
)
except Exception as e:
LOG.warning(f"Multi-turn rollout generation failed: {e}")
gen_text = ""
conv.append({"role": "assistant", "content": gen_text})
# Render full conversation through chat template, then extract
# everything after the original prompt as the "completion" text.
# This preserves role markers and formatting between turns.
full_rendered = self.processing_class.apply_chat_template(
conv, tokenize=False, add_generation_prompt=False
)
prompt_rendered = self.processing_class.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
completion_text = full_rendered[len(prompt_rendered) :]
extended_completions.append(completion_text)
return extended_completions
class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
"""EBFT trainer using synchronous GRPO (standard vLLM generation)."""
pass
class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
"""EBFT trainer using async GRPO (prefetches next batch during training)."""
pass

View File

@@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer):
"""
def __init__(self, *args, **kwargs):
# When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
# The communicator is not needed because weight sync happens via filesystem + HTTP,
# and it fails when vLLM and a trainer rank share the same CUDA device.
# Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
# merged weight sync. NCCL is only needed for the standard update_named_param
# path which broadcasts tensors through the communicator.
training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
if training_args is not None and getattr(
training_args, "vllm_lora_sync", False
):
_skip_nccl = False
if training_args is not None:
if getattr(training_args, "vllm_lora_sync", False):
_skip_nccl = True # LoRA sync uses filesystem + HTTP
elif getattr(training_args, "async_prefetch", False):
# Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
# init_communicator allocates device tensors on rank 0 only, which
# causes DDP to see different param counts across ranks.
# The communicator is initialized lazily on first weight sync instead.
_skip_nccl = True
if _skip_nccl:
from trl.generation.vllm_generation import VLLMGeneration
_orig_init_vllm = VLLMGeneration._init_vllm
@@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
VLLMGeneration._init_vllm = _init_vllm_no_communicator
super().__init__(*args, **kwargs)
try:
super().__init__(*args, **kwargs)
finally:
# Restore original _init_vllm so other trainers aren't affected
if _skip_nccl:
VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined]
# FP8 models: zero out the pad token embedding so that padding
# positions have zero hidden states throughout the network.
@@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._executor = None
def _submit_generation(self):
"""Submit the next background generation job."""
"""Submit the next background generation job.
With multi-process (DDP/FSDP), only rank 0 generates to avoid
cross-rank NCCL collectives from background threads. Non-rank-0
processes enqueue a sentinel ``None`` that is replaced by a
broadcast in ``_prepare_inputs_legacy_async``.
"""
rank0_only = self.accelerator.num_processes > 1
if rank0_only and not self.accelerator.is_main_process:
# Non-rank-0: nothing to generate; enqueue a resolved None future
f: concurrent.futures.Future = concurrent.futures.Future()
f.set_result(None)
self._async_queue.put(f)
return
batch = next(self._prompt_iter)
future = self._executor.submit(self._generate_only, batch)
future = self._executor.submit(self._generate_only, batch, rank0_only)
self._async_queue.put(future)
# ------------------------------------------------------------------
# Broadcast rollout (legacy async, multi-process)
# ------------------------------------------------------------------
def _broadcast_rollout(self, rollout: dict | None) -> dict:
"""Broadcast a rank0-only rollout dict to all ranks (main thread).
Rank 0 has the full rollout dict from ``_generate_only``; other ranks
have ``None``. After broadcast, tensors are moved to each rank's
local device.
"""
import torch.distributed as dist
obj_list = [rollout if self.accelerator.is_main_process else None]
dist.broadcast_object_list(obj_list, src=0)
rollout = obj_list[0]
assert rollout is not None, "broadcast_object_list failed to deliver rollout"
# Move tensors to local device (broadcast deserializes to CPU)
device = self.accelerator.device
for key, val in rollout.items():
if isinstance(val, torch.Tensor) and val.device != device:
rollout[key] = val.to(device)
return rollout
# ------------------------------------------------------------------
# Weight sync
# ------------------------------------------------------------------
@@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer):
for Float8), and also safe for concurrent use since it never modifies base
weights in-place.
"""
model = self.vllm_generation.model
accelerator = self.vllm_generation.accelerator
vllm_client = self.vllm_generation.vllm_client
fix_name = self.vllm_generation._fix_param_name_to_vllm
if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
return
# In multi-GPU async mode, we skip NCCL communicator init to avoid
# DDP param count mismatch and NCCL device conflicts. Weight sync
# uses the HTTP-only fallback in batch_update_named_params instead.
model = self.vllm_generation.model
vllm_client = self.vllm_generation.vllm_client
fix_name = self.vllm_generation._fix_param_name_to_vllm
# Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
lora_info = {}
for mod_name, module in model.base_model.model.named_modules():
@@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
weight_name = pname.replace(".weight_scale_inv", ".weight")
scale_inv_lookup[weight_name] = pparam.data
# Iterate all parameters, computing merged weights for LoRA layers.
# Skip LoRA-specific params and FP8 scale params (scales will be
# recomputed by vLLM when it receives the merged bf16 weight).
# Only sync parameters that have LoRA modifications — skip unchanged
# base weights to avoid OOM on the vLLM GPU from allocating the entire
# model's worth of NCCL receive buffers.
params_to_sync = []
compute_dtype = torch.bfloat16
for name, param in model.named_parameters():
vllm_name = name.removeprefix("base_model.model.").replace(
".base_layer", ""
@@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer):
continue
if "original_module" in vllm_name:
continue
# Skip FP8 quantization scale parameters - they are recomputed
# on the vLLM side when we update the weight itself
if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
continue
if not vllm_name.endswith(".weight"):
continue
# fix_name strips modules_to_save.default. prefix
raw_mod_path = vllm_name[: -len(".weight")]
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
mod_path = vllm_name[: -len(".weight")]
# Sync weights that have LoRA adapters OR are modules_to_save
is_lora = mod_path in lora_info
is_modules_to_save = raw_mod_path != mod_path # fix_name stripped a prefix
if not is_lora and not is_modules_to_save:
continue
data = param.data
compute_dtype = torch.bfloat16
if vllm_name.endswith(".weight"):
# Dequantize FP8 weights before merging
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
scale_inv = scale_inv_lookup[name]
# Block dequantization: weight * scale_inv (with broadcasting)
fp8_bf16 = data.to(compute_dtype)
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
# Block-quantized: scale_inv shape (rows/block, cols/block)
sr, sc = scale_inv.shape
br = fp8_bf16.shape[0] // sr # block height
bc = fp8_bf16.shape[1] // sc # block width
# Reshape → multiply by block scale → reshape back
data = (
fp8_bf16.reshape(sr, br, sc, bc)
* scale_inv[:, None, :, None].to(compute_dtype)
).reshape(fp8_bf16.shape)
elif scale_inv.dim() <= 1:
# Per-tensor or per-channel scale
data = fp8_bf16 * scale_inv.to(compute_dtype)
else:
data = fp8_bf16
elif data.dtype == torch.float8_e4m3fn:
# FP8 but no scale found - just cast (lossy)
data = data.to(compute_dtype)
# Dequantize FP8 weights before merging
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
scale_inv = scale_inv_lookup[name]
fp8_bf16 = data.to(compute_dtype)
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
sr, sc = scale_inv.shape
br = fp8_bf16.shape[0] // sr
bc = fp8_bf16.shape[1] // sc
data = (
fp8_bf16.reshape(sr, br, sc, bc)
* scale_inv[:, None, :, None].to(compute_dtype)
).reshape(fp8_bf16.shape)
elif scale_inv.dim() <= 1:
data = fp8_bf16 * scale_inv.to(compute_dtype)
else:
data = fp8_bf16
elif data.dtype == torch.float8_e4m3fn:
data = data.to(compute_dtype)
mod_path = vllm_name[: -len(".weight")]
if mod_path in lora_info:
A, B, s = lora_info[mod_path]
merged = data.to(compute_dtype) + s * (
B.to(compute_dtype) @ A.to(compute_dtype)
)
data = merged
if is_lora:
A, B, s = lora_info[mod_path]
merged = data.to(compute_dtype) + s * (
B.to(compute_dtype) @ A.to(compute_dtype)
)
params_to_sync.append((vllm_name, merged))
else:
# modules_to_save: send raw weight (no LoRA merge needed)
params_to_sync.append((vllm_name, data.to(compute_dtype)))
params_to_sync.append((vllm_name, data))
# Batch sync all params in one HTTP+NCCL call (vs individual calls)
# Batch sync only LoRA-modified params via HTTP+NCCL
if params_to_sync:
sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
logger.info(
f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
)
vllm_client.batch_update_named_params(params_to_sync)
# Reset prefix cache after weight update
@@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
vllm_client = self.vllm_generation.vllm_client
url = f"{vllm_client.base_url}/set_lora_adapter/"
sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
response = requests.post(
url,
json={
@@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
"lora_int_id": self._lora_sync_version,
"lora_path": adapter_path,
},
timeout=30,
timeout=sync_timeout,
)
if response.status_code != 200:
logger.warning(
@@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
step = self.state.global_step
interval = self.args.vllm_sync_interval
if step != self._last_synced_step and step % interval == 0:
if step == 0:
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
self._last_synced_step = step
return
if getattr(self.args, "vllm_lora_sync", False):
if step == 0:
logger.info("Skipping LoRA sync at step 0 (no training yet)")
self._last_synced_step = step
return
# Native LoRA sync: save adapter to filesystem, vLLM loads it directly
self._sync_lora_adapter()
else:
@@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Background-thread generation (no scoring)
# ------------------------------------------------------------------
def _generate_single_turn(self, prompts, **kwargs):
def _generate_single_turn(self, prompts, *args, **kwargs):
"""Override to prevent weight sync from background thread and to use
no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
is_bg = threading.current_thread() is not threading.main_thread()
@@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._patched_sync_weights = True
try:
return super()._generate_single_turn(prompts, **kwargs)
return super()._generate_single_turn(prompts, *args, **kwargs)
finally:
if saved_step is not None:
self._last_loaded_step = saved_step
@@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
output = vg.vllm_client.chat(
messages=unique_prompts,
**sampling_params,
chat_template_kwargs=vg.chat_template_kwargs,
tools=vg.tools,
chat_template=vg.chat_template,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=getattr(self, "chat_template", None),
)
else:
output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
@@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
logps_diff = per_token_logps_diff
is_ratio = torch.exp(logps_diff)
is_floor = 1.0 / is_cap # symmetric floor (e.g., cap=3.0 -> floor=0.333)
if is_mode in ("sequence_truncate", "token_truncate"):
is_ratio = torch.clamp(is_ratio, max=is_cap)
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
is_ratio = is_ratio.clamp(min=is_floor)
data["importance_sampling_ratio"] = is_ratio
# --- Collect rewards (launched before logprobs, should be done) ---
@@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
seq_is = is_mode in ("sequence_mask", "sequence_truncate")
logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
is_ratio = torch.exp(logps_diff)
# Symmetric floor clamp (matches non-streaming path at line ~1651)
is_floor = 1.0 / is_cap
if is_mode in ("sequence_truncate", "token_truncate"):
is_ratio = torch.clamp(is_ratio, max=is_cap)
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
is_ratio = is_ratio.clamp(min=is_floor)
if "importance_sampling_ratio" not in data:
total = len(data["prompt_ids"])
shape = (total, 1) if seq_is else (total, is_ratio.size(1))
@@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
rollout = future.result()
self._submit_generation()
# With multi-process, only rank 0 generated. Broadcast to all ranks.
if self.accelerator.num_processes > 1:
rollout = self._broadcast_rollout(rollout)
if self.args.streaming_partial_batch:
micro_batches = self._score_streaming(rollout)
else:

View File

@@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback):
logger.info("=" * 60)
if self.trainer.axolotl_cfg.use_wandb:
if wandb.run is not None:
wandb.log(
if wandb.run is not None: # type: ignore[attr-defined]
wandb.log( # type: ignore[attr-defined]
{
"generated_samples": wandb.Table(
"generated_samples": wandb.Table( # type: ignore[attr-defined]
columns=[
"step",
"original",

View File

@@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__)
def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
from transformers import is_torch_xpu_available
"""Batched weight sync — uses NCCL if communicator available, HTTP otherwise."""
has_communicator = getattr(self, "communicator", None) is not None
if chunk_size is None:
chunks = [params]
else:
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
if has_communicator:
# Fast path: metadata via HTTP, tensors via NCCL
from transformers import is_torch_xpu_available
for chunk in chunks:
param_metadata = [
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
if chunk_size is None:
chunks = [params]
else:
self.communicator.group.barrier()
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
for chunk in chunks:
param_metadata = [
{
"name": name,
"dtype": str(weights.dtype),
"shape": list(weights.shape),
}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(
url, json={"params": param_metadata}, timeout=120
)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
else:
# HTTP-only path: encode tensor data in request body (no NCCL needed).
# Batch by byte size to avoid huge HTTP payloads.
MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024 # 10 MB
HTTP_TIMEOUT = 120 # seconds per request
payload: list[dict] = []
payload_bytes = 0
url = f"{self.base_url}/http_update_weights/"
def _flush(p: list[dict]) -> None:
if not p:
return
response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)
from axolotl.utils.weight_serde import encode_for_http
for name, weights in params:
entry = encode_for_http(name, weights)
entry_bytes = weights.nelement() * weights.element_size()
# Flush current batch if adding this entry would exceed limit
if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST:
_flush(payload)
payload = []
payload_bytes = 0
payload.append(entry)
payload_bytes += entry_bytes
_flush(payload) # send remaining
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):

View File

@@ -0,0 +1,9 @@
"""
module for EBFT style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.ebft")

View File

@@ -0,0 +1,129 @@
"""
Dataset transform for multi-turn chat data with structured EBFT (vLLM mode).
Three variants:
1. `transform` — Uses the FIRST assistant turn as the generation target.
Passes remaining turns as `remaining_turns` for sequential rollout.
The trainer generates turn 1 via GRPO/vLLM, then sequentially generates
subsequent assistant turns, comparing the full conversation to GT.
2. `transform_last_turn` — Uses the LAST assistant turn as the target.
Simplest approach: the full conversation history is the prompt.
3. `transform_all_turns` — Explodes each conversation into N examples
(one per assistant turn). Each turn is an independent training example.
Use with batched=True.
Supports OpenAI chat format:
{"messages": [{"role": ..., "content": ...}, ...]}
"""
def transform(cfg, **kwargs):
"""Multi-turn with sequential rollout.
Returns the first assistant turn as ground_truth, plus remaining_turns
for the trainer to do sequential rollout generation.
"""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if not messages:
return {"prompt": [], "ground_truth": ""}
# Split at first assistant turn
prompt_msgs = []
first_gt = None
remaining = []
found_first = False
for msg in messages:
if msg["role"] == "assistant" and not found_first:
first_gt = msg["content"]
found_first = True
elif found_first:
remaining.append(msg)
else:
prompt_msgs.append(msg)
if first_gt is None:
return {"prompt": prompt_msgs, "ground_truth": ""}
# Store only the first assistant turn as ground_truth. The full multi-turn
# GT is reconstructed in the reward function via chat template rendering
# (using remaining_turns), which preserves role markers between turns.
return {
"prompt": prompt_msgs,
"ground_truth": first_gt,
"remaining_turns": remaining,
}
return transform_fn, {
"remove_columns": "__all__",
}
def transform_last_turn(cfg, **kwargs):
"""Single-turn: use the last assistant turn as the generation target."""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if not messages:
return {"prompt": [], "ground_truth": ""}
# Find all assistant turns
history = []
last_prompt = []
last_gt = ""
for msg in messages:
if msg["role"] == "assistant":
last_prompt = list(history)
last_gt = msg["content"]
history.append(msg)
return {
"prompt": last_prompt,
"ground_truth": last_gt,
}
return transform_fn, {
"remove_columns": "__all__",
}
def transform_all_turns(cfg, **kwargs):
"""Explode: one example per assistant turn.
Use with datasets.map(batched=True) to produce N examples from
each N-turn conversation.
Usage in YAML:
type: ebft_chat_multiturn.transform_all_turns
"""
def transform_fn(examples, tokenizer=None):
all_prompts = []
all_ground_truths = []
messages_list = examples.get("messages", examples.get("conversations", []))
for messages in messages_list:
history = []
for msg in messages:
if msg["role"] == "assistant":
all_prompts.append(list(history))
all_ground_truths.append(msg["content"])
history.append(msg)
return {
"prompt": all_prompts,
"ground_truth": all_ground_truths,
}
return transform_fn, {
"remove_columns": "__all__",
"batched": True,
}

View File

@@ -0,0 +1,20 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode.
Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer (prompt + ground_truth).
"""
def transform(cfg, **kwargs):
def transform_fn(example, tokenizer=None):
return {
"prompt": [
{"role": "user", "content": example["input"]},
],
"ground_truth": example["output"],
}
return transform_fn, {
"remove_columns": "__all__",
}

View File

@@ -0,0 +1,319 @@
"""
Dataset transform for reasoning/thinking datasets with EBFT.
Handles datasets where assistant responses contain <think>...</think> reasoning
traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs).
Two variants:
1. `transform` — For structured EBFT (vLLM mode):
Returns prompt + ground_truth with thinking tags preserved.
Feature matching compares full responses (thinking + answer).
2. `transform_answer_only` — For structured EBFT (vLLM mode):
Strips <think>...</think> from ground_truth, so feature matching
only scores the final answer portion. Use when reasoning chains
can vary but the answer should match.
3. `transform_strided` — For strided EBFT:
Tokenizes the full conversation with thinking traces.
Optionally masks thinking tokens from CE loss (labels=-100 for think spans)
while still placing anchors in thinking regions for feature matching.
All variants work with OpenAI chat format:
{"messages": [{"role": "...", "content": "<think>...</think>Answer"}]}
"""
import re
def _strip_thinking(text: str) -> str:
"""Remove <think>...</think> blocks from text."""
return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
def _extract_thinking(text: str) -> tuple[str, str]:
"""Split text into (thinking, answer) parts."""
match = re.search(r"<think>(.*?)</think>\s*(.*)", text, flags=re.DOTALL)
if match:
return match.group(1).strip(), match.group(2).strip()
return "", text.strip()
def transform(cfg, **kwargs):
"""Full response including thinking traces for feature matching.
For datasets where assistant content has <think>...</think> tags in the
content field. The ground_truth includes the full content (thinking + answer).
"""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
prompt_msgs_snapshot = None
ground_truth = ""
for msg_idx, msg in enumerate(messages):
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(messages[:msg_idx])
ground_truth = msg["content"]
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_split_thinking(cfg, **kwargs):
"""Split <think> tags into reasoning_content field for native chat template handling.
For datasets where thinking is embedded in the content field as <think>...</think>.
Splits it into separate reasoning_content and content fields so the model's
chat template can format it natively (e.g., Qwen3.5's reasoning_content support).
The prompt messages are passed through with reasoning_content properly split,
so vLLM generation with enable_thinking=true produces comparable outputs.
The ground_truth is the full assistant response (thinking + answer) for
feature matching.
Also works for:
- <reasoning>...</reasoning> tags
- <|begin_of_thought|>...<|end_of_thought|> tags
"""
_THINKING_PAIRS = [
("<think>", "</think>"),
("<reasoning>", "</reasoning>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
]
def _split_msg_thinking(msg):
"""Split thinking from assistant message content into reasoning_content.
Always includes reasoning_content key on assistant messages (empty string
if no thinking tags found) to ensure consistent HF dataset schema across
all examples in a batch.
"""
if msg["role"] != "assistant":
return msg
content = msg.get("content", "")
# Already has reasoning_content — pass through
if "reasoning_content" in msg:
return msg
for open_tag, close_tag in _THINKING_PAIRS:
if open_tag in content and close_tag in content:
start = content.find(open_tag)
end = content.find(close_tag)
thinking = content[start + len(open_tag) : end].strip()
answer = content[end + len(close_tag) :].strip()
return {
**msg,
"reasoning_content": thinking,
"content": answer,
}
# No thinking tags — still add reasoning_content for schema consistency
return {**msg, "reasoning_content": ""}
def _normalize_msg(msg):
"""Ensure every message has {role, content, reasoning_content} for HF schema consistency."""
return {
"role": msg.get("role", ""),
"content": msg.get("content", ""),
"reasoning_content": msg.get("reasoning_content", ""),
}
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
# Split thinking in all assistant messages, then normalize schema
split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages]
# Build prompt (all messages except last assistant) and ground_truth
prompt_msgs = []
prompt_msgs_snapshot = None
ground_truth = ""
for msg in split_messages:
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(prompt_msgs)
# ground_truth is the FULL content for feature matching
thinking = msg.get("reasoning_content", "")
answer = msg.get("content", "")
if thinking:
ground_truth = f"<think>\n{thinking}\n</think>\n\n{answer}"
else:
ground_truth = answer
prompt_msgs.append(msg)
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else split_messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_answer_only(cfg, **kwargs):
"""Strip thinking from ground_truth — match features on answer only."""
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
prompt_msgs = []
prompt_msgs_snapshot = None
ground_truth = ""
for msg in messages:
if msg["role"] == "assistant":
prompt_msgs_snapshot = list(prompt_msgs)
ground_truth = _strip_thinking(msg["content"])
prompt_msgs.append(msg)
return {
"prompt": prompt_msgs_snapshot
if prompt_msgs_snapshot is not None
else messages[:-1],
"ground_truth": ground_truth,
}
return transform_fn, {"remove_columns": "__all__"}
def transform_strided(cfg, **kwargs):
"""For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss.
Config options (via cfg):
- ebft.mask_thinking_ce: bool (default False)
If True, set labels=-100 for tokens inside <think>...</think> blocks.
Feature matching still uses these positions (anchors are placed everywhere
in the completion span). Only CE auxiliary loss is affected.
"""
seq_len = cfg.sequence_len
mask_thinking = False
if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"):
mask_thinking = cfg.ebft.mask_thinking_ce
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if tokenizer is None:
for m in messages:
if m.get("role") == "user":
return {"prompt": m["content"]}
return {"prompt": str(messages)}
pad_id = (
tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id
)
# Tokenize the full conversation with the chat template
full_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
full_enc = tokenizer(
full_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)
input_ids = full_enc["input_ids"]
# Build labels: -100 for non-assistant tokens
labels = [-100] * len(input_ids)
# Find assistant turn boundaries using incremental tokenization.
# Only the FINAL assistant turn is marked as trainable.
prefix_messages = []
final_start = None
final_end = None
for msg in messages:
if msg["role"] == "assistant":
prefix_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=True,
)
prefix_ids = tokenizer(
prefix_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
start = len(prefix_ids)
prefix_messages.append(msg)
with_turn_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=False,
)
with_turn_ids = tokenizer(
with_turn_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
end = len(with_turn_ids)
# Record this turn's boundaries; only the last one will be used
final_start = start
final_end = end
else:
prefix_messages.append(msg)
# Mark only the final assistant turn as trainable
if final_start is not None and final_end is not None:
for i in range(final_start, min(final_end, len(labels))):
labels[i] = input_ids[i]
# Optionally mask <think>...</think> tokens within this turn.
# Find think spans by scanning for <think> and </think> token IDs
# directly in the input_ids (robust to tokenization alignment).
if mask_thinking:
think_open_id = tokenizer.convert_tokens_to_ids("<think>")
think_close_id = tokenizer.convert_tokens_to_ids("</think>")
if think_open_id != tokenizer.unk_token_id:
# Scan from before the assistant turn start to catch
# <think> tags that are part of the template prefix
scan_start = max(0, final_start - 5)
in_think = False
for i in range(scan_start, min(final_end, len(labels))):
if input_ids[i] == think_open_id:
in_think = True
if in_think and i >= final_start:
labels[i] = -100
if input_ids[i] == think_close_id:
in_think = False
if i >= final_start:
labels[i] = -100
# Derive prompt_length
prompt_length = len(input_ids)
for i, lbl in enumerate(labels):
if lbl != -100:
prompt_length = i
break
# Pad
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
return transform_fn, {"remove_columns": "__all__"}

View File

@@ -0,0 +1,110 @@
"""
Dataset transform for multi-turn chat data with strided EBFT.
Tokenizes conversations using the model's chat template, producing input_ids
with labels=-100 for system/user turns and real labels for assistant turns.
The strided trainer places anchors only within assistant completion spans.
Works with datasets in OpenAI chat format:
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
"""
def transform(cfg, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
messages = example.get("messages", example.get("conversations", []))
if tokenizer is None:
# For preview: just return the first user message
for m in messages:
if m.get("role") == "user":
return {"prompt": m["content"]}
return {"prompt": str(messages)}
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
# Tokenize the full conversation with the chat template
full_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
full_enc = tokenizer(
full_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)
input_ids = full_enc["input_ids"]
# Build labels: -100 for everything except assistant turns.
# Strategy: tokenize incrementally to find assistant turn boundaries.
labels = [-100] * len(input_ids)
# Tokenize prefix up to each assistant turn to find boundaries
prefix_messages = []
for msg in messages:
if msg["role"] == "assistant":
# Tokenize prefix (everything before this assistant turn + generation prompt)
prefix_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=True,
)
prefix_ids = tokenizer(
prefix_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
start = len(prefix_ids)
# Tokenize prefix + this assistant turn
prefix_messages.append(msg)
with_turn_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=False,
)
with_turn_ids = tokenizer(
with_turn_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
end = len(with_turn_ids)
# Mark assistant tokens as trainable
for i in range(start, min(end, len(labels))):
labels[i] = input_ids[i]
else:
prefix_messages.append(msg)
# Derive prompt_length as the position of the first non-masked label
prompt_length = len(input_ids) # default: all masked
for i, lbl in enumerate(labels):
if lbl != -100:
prompt_length = i
break
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
return transform_fn, {
"remove_columns": "__all__",
}

View File

@@ -0,0 +1,80 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.
Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""
def transform(cfg, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
# Extract prompt and completion from the example
prompt_text = example.get(
"input", example.get("prompt", example.get("question", ""))
)
completion_text = example.get(
"output", example.get("completion", example.get("answer", ""))
)
if tokenizer is None:
return {"prompt": prompt_text}
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
# Tokenize prompt and completion separately
prompt_enc = tokenizer(
prompt_text,
truncation=False,
add_special_tokens=True,
return_tensors=None,
)
completion_enc = tokenizer(
completion_text,
truncation=False,
add_special_tokens=False,
return_tensors=None,
)
prompt_ids = prompt_enc["input_ids"]
completion_ids = completion_enc["input_ids"]
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
total_len = len(prompt_ids) + len(completion_ids)
if total_len > seq_len:
# Truncate completion first, then prompt if needed
max_completion = seq_len - len(prompt_ids)
if max_completion < 1:
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
prompt_ids = prompt_ids[: seq_len - 1]
completion_ids = completion_ids[:1]
else:
completion_ids = completion_ids[:max_completion]
input_ids = prompt_ids + completion_ids
prompt_length = len(prompt_ids)
# Labels: -100 for prompt tokens, input_ids for completion tokens
labels = [-100] * prompt_length + completion_ids
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
# Signal to remove all original columns (filtered to existing ones at map time)
return transform_fn, {
"remove_columns": "__all__",
}

View File

@@ -241,6 +241,23 @@ def main(script_args: ScriptArguments):
app = FastAPI(lifespan=lifespan)
# --- Access logging middleware ---
import time as _time
@app.middleware("http")
async def access_log_middleware(request, call_next):
t0 = _time.monotonic()
response = await call_next(request)
elapsed = _time.monotonic() - t0
logger.info(
"%s %s %d %.3fs",
request.method,
request.url.path,
response.status_code,
elapsed,
)
return response
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
@@ -300,7 +317,11 @@ def main(script_args: ScriptArguments):
import vllm
from packaging.version import Version
from vllm.sampling_params import GuidedDecodingParams
try:
from vllm.sampling_params import GuidedDecodingParams
except ImportError:
GuidedDecodingParams = None # not available in vLLM 0.17+
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
prompts: list[dict[str, Any]] = []
@@ -362,7 +383,12 @@ def main(script_args: ScriptArguments):
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
# Use run_in_executor so blocking recv() doesn't freeze the event loop
# (allows /set_lora_adapter/ and other endpoints to be served concurrently)
loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
all_outputs = [
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
]
@@ -404,7 +430,10 @@ def main(script_args: ScriptArguments):
}
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs))
@@ -474,11 +503,51 @@ def main(script_args: ScriptArguments):
)
return {"message": f"Batch update for {len(params_list)} params"}
class HTTPWeightUpdateRequest(BaseModel):
"""Weight update via HTTP (no NCCL needed)."""
params: list[
dict
] # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]
@app.post("/http_update_weights/")
async def http_update_weights(request: HTTPWeightUpdateRequest):
"""Update model weights via HTTP — no NCCL communicator required.
Tensor data is sent as base64-encoded raw bytes in the request body.
Slower than NCCL for large models but works without cross-process setup.
"""
from axolotl.utils.weight_serde import (
decode_from_http,
encode_for_ipc,
)
weights_to_load = [decode_from_http(p) for p in request.params]
# Send all weights in a single IPC call. Tensors don't survive
# vLLM's multiproc IPC, so serialize as raw bytes + metadata.
param_entries = [
encode_for_ipc(name, weight) for name, weight in weights_to_load
]
kwargs = {
"method": "http_load_weights_batch",
"kwargs": {"params": param_entries},
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": f"HTTP weight update for {len(weights_to_load)} params"}
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"})
results = [conn.recv() for conn in connections]
loop = asyncio.get_running_loop()
results = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
return {"message": f"Reset prefix cache: {all(results)}"}
@app.post("/close_communicator/")

View File

@@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
model = self.model_runner.model
params_dict = dict(model.named_parameters())
# Handle VLM models where trainer and vLLM use different prefixes.
# Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..."
# vLLM (Qwen3.5): "language_model.model.layers.X..."
if name not in params_dict:
# Try common prefix remappings
for src_prefix, dst_prefix in [
("model.language_model.layers.", "language_model.model.layers."),
("model.layers.", "language_model.model.layers."),
]:
if name.startswith(src_prefix):
name = dst_prefix + name[len(src_prefix) :]
break
# Check if this is a simple direct param (exists as-is)
if name in params_dict:
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
@@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
return
# Fallback: try load_weights (may work for non-stacked params)
logger.warning("Falling back to load_weights for param: %s", name)
# Log the actual param names available for debugging
sample_keys = [
k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k
][:3]
logger.warning(
"Falling back to load_weights for param: %s (sample vLLM keys: %s)",
name,
sample_keys,
)
model.load_weights(weights=[(name, weight)])
def update_named_param(self, name, dtype, shape):
@@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
# Load weights using direct set (handles stacked params)
for name, weight in weights_to_load:
self._direct_set_weight(name, weight)
def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]):
"""Load weights received via HTTP (no NCCL needed)."""
for name, weight in weights:
self._direct_set_weight(name, weight.to(self.device))
def http_load_weight(self, **kwargs):
"""Load a single weight received via HTTP (no NCCL needed).
Reconstructs the tensor from raw bytes since tensors don't survive
vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights``
which handles TP sharding and stacked-param packing automatically.
"""
from axolotl.utils.weight_serde import decode_from_ipc
name, weight = decode_from_ipc(kwargs)
model = self.model_runner.model
model.load_weights(weights=[(name, weight)])
def http_load_weights_batch(self, params: list[dict]):
"""Load multiple weights in a single IPC call.
Uses vLLM's ``load_weights`` which handles TP sharding automatically.
"""
from axolotl.utils.weight_serde import decode_from_ipc
model = self.model_runner.model
weights = [decode_from_ipc(p) for p in params]
model.load_weights(weights=weights)

View File

@@ -138,7 +138,11 @@ def setup_reference_model(
model_ref = None # explicit setting to None
else:
reference_model: bool = True
if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
trl_cfg = getattr(cfg, "trl", None)
if (
cfg.rl in {RLType.GRPO, RLType.EBFT}
and getattr(trl_cfg, "beta", 0) == 0
):
reference_model = False
# load the model again for model_ref/baseline
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
@@ -206,7 +210,7 @@ def execute_training(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
gather_outputs=cfg.rl is RLType.GRPO,
gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT},
device_mesh=trainer.accelerator.torch_device_mesh,
)
)

View File

@@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
].append(pred_step_text)
row_index += 1
if logger == "wandb":
# type: ignore[attr-defined]
wandb.run.log(
wandb.run.log( # type: ignore[attr-defined]
{
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
table_data
@@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
artifact = wandb.Artifact( # type: ignore[attr-defined]
f"config-{wandb.run.id}", # type: ignore[attr-defined]
type="axolotl-config",
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
wandb.log_artifact(artifact) # type: ignore[attr-defined]
wandb.save(temp_file.name) # type: ignore[attr-defined]
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
@@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()
artifact = wandb.Artifact(
f"chat-template-{wandb.run.id}", type="jinja-template"
artifact = wandb.Artifact( # type: ignore[attr-defined]
f"chat-template-{wandb.run.id}", # type: ignore[attr-defined]
type="jinja-template",
)
artifact.add_file(temp_ct_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_ct_file.name)
wandb.log_artifact(artifact) # type: ignore[attr-defined]
wandb.save(temp_ct_file.name) # type: ignore[attr-defined]
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
@@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
else:
skip_upload = True
if not skip_upload:
artifact = wandb.Artifact(
f"deepspeed-config-{wandb.run.id}",
artifact = wandb.Artifact( # type: ignore[attr-defined]
f"deepspeed-config-{wandb.run.id}", # type: ignore[attr-defined]
type="deepspeed-config",
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
wandb.log_artifact(artifact) # type: ignore[attr-defined]
wandb.save(temp_file.name) # type: ignore[attr-defined]
LOG.info(
"The DeepSpeed config has been saved to the WandB run under files."
)

View File

@@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback):
if not getattr(cfg, "generate_samples", False):
return
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
LOG.info(
f"Using eval dataloader for generation at step {state.global_step}"
)
except Exception as e:
LOG.warning(f"Could not get eval dataloader: {e}")
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
LOG.info(
f"Using train dataloader for generation at step {state.global_step}"
f"Using eval dataloader for generation at step {state.global_step}"
)
except Exception as e:
LOG.warning(f"Could not get eval dataloader: {e}")
dataloader = None
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
temperature=getattr(cfg, "generation_temperature", 0.7),
top_p=getattr(cfg, "generation_top_p", None),
top_k=getattr(cfg, "generation_top_k", None),
do_sample=getattr(cfg, "generation_do_sample", True),
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
LOG.info(
f"Using train dataloader for generation at step {state.global_step}"
)
self._log_samples(samples, state.global_step)
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
temperature=getattr(cfg, "generation_temperature", 0.7),
top_p=getattr(cfg, "generation_top_p", None),
top_k=getattr(cfg, "generation_top_k", None),
do_sample=getattr(cfg, "generation_do_sample", True),
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
)
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples to console and W&B."""
@@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback):
try:
import wandb
if wandb.run is not None:
wandb.log(
if wandb.run is not None: # type: ignore[attr-defined]
wandb.log( # type: ignore[attr-defined]
{
f"samples/sample_{i + 1}": wandb.Html(
f"samples/sample_{i + 1}": wandb.Html( # type: ignore[attr-defined]
f"<pre>{wandb_text}</pre>"
)
},

View File

@@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer
from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.lock import FileLockLoader
@@ -173,7 +174,7 @@ def _drop_long_sequences(
return (len_prompt + len_completion) <= sequence_len
if rl in {RLType.GRPO, RLType.GDPO}:
if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}:
return True
raise ValueError("Unknown RL type")
@@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.EBFT:
ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
map_kwargs: dict[str, Any] = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
# Handle remove_columns: "__all__" removes all original columns,
# or filter a list to only columns that exist in the dataset
if "remove_columns" in map_kwargs:
ds_columns = (
dataset.column_names
if isinstance(dataset, Dataset)
else dataset[split].column_names
if isinstance(dataset, DatasetDict)
else []
)
if map_kwargs["remove_columns"] == "__all__":
map_kwargs["remove_columns"] = list(ds_columns)
else:
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
split_datasets[i] = _map_dataset(
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
)

View File

@@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__)
class EBFTConfig(BaseModel):
"""Configuration for Energy-Based Fine-Tuning (EBFT)"""
feature_layers: list[float] = Field(
default=[0.25, 0.5, 0.75],
json_schema_extra={
"description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])"
},
)
embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = (
Field(
default="last_token",
json_schema_extra={
"description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"
},
)
)
use_whitening: bool = Field(
default=False,
json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
)
alignment_coef: float = Field(
default=1.0,
json_schema_extra={
"description": "Coefficient for alignment reward (cosine similarity with ground truth)"
},
)
diversity_coef: float = Field(
default=1.0,
json_schema_extra={
"description": "Coefficient for diversity penalty (pairwise similarity between samples)"
},
)
ce_coef: float = Field(
default=0.0,
json_schema_extra={
"description": "Cross-entropy loss coefficient on ground-truth tokens"
},
)
adaptive_max_tokens: bool = Field(
default=True,
json_schema_extra={
"description": "Set per-batch max_tokens based on ground-truth length"
},
)
gt_length_multiplier: float = Field(
default=1.5,
ge=0.1,
json_schema_extra={
"description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
},
)
# Strided mode fields (for unstructured text)
mode: Literal["structured", "strided"] = Field(
default="structured",
json_schema_extra={
"description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
},
)
stride: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Stride between anchor points (tokens)"},
)
context_length: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Context window size per block"},
)
generate_max_len: int = Field(
default=8,
ge=1,
json_schema_extra={"description": "Tokens to generate per block"},
)
n_samples_per_prompt: int = Field(
default=4,
ge=1,
json_schema_extra={"description": "Independent rollouts per document"},
)
temperature: float = Field(
default=0.6,
ge=0.0,
json_schema_extra={
"description": "Sampling temperature for strided generation"
},
)
top_p: float = Field(
default=1.0,
ge=0.0,
le=1.0,
json_schema_extra={"description": "Top-p nucleus sampling threshold"},
)
rl_coef: float = Field(
default=1.0,
json_schema_extra={"description": "RL policy gradient loss coefficient"},
)
advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
default="rloo",
json_schema_extra={
"description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
},
)
min_completion_prefix: int = Field(
default=0,
ge=0,
json_schema_extra={
"description": "Minimum tokens into completion before placing anchors. "
"Skips anchors too close to the prompt boundary where features are dominated by prompt context."
},
)
class AxolotlInputConfig(
ModelInputConfig,
ModelOutputConfig,
@@ -131,7 +244,7 @@ class AxolotlInputConfig(
rl: RLType | None = Field(
default=None,
json_schema_extra={
"description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
"description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'"
},
)
trl: TRLConfig | None = Field(
@@ -140,6 +253,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
ebft: EBFTConfig | None = Field(
default=None,
json_schema_extra={
"description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(

View File

@@ -35,6 +35,7 @@ class RLType(str, Enum):
ORPO = "orpo"
KTO = "kto"
SIMPO = "simpo"
EBFT = "ebft"
class ChatTemplate(str, Enum):

View File

@@ -1,6 +1,6 @@
"""Pydantic models for TRL trainer configuration"""
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
@@ -133,6 +133,20 @@ class TRLConfig(BaseModel):
"description": "Penalty for tokens that appear in prompt and generated text."
},
)
generation_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Additional generation parameters passed to vLLM SamplingParams. "
"Useful for stop_token_ids, seed, frequency_penalty, etc."
},
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Additional kwargs for the chat template. "
"E.g., {enable_thinking: false} for Qwen3.5 models."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={

View File

@@ -1482,6 +1482,124 @@ class DistributedValidationMixin:
return self
class EBFTValidationMixin:
"""Validation for EBFT (Energy-Based Fine-Tuning) configuration."""
@model_validator(mode="before")
@classmethod
def check_ebft_config_required(cls, data):
"""rl: ebft requires an ebft config section."""
if data.get("rl") == "ebft" and not data.get("ebft"):
raise ValueError(
"`ebft` config section is required when `rl: ebft` is set. "
"Add an `ebft:` section with at least `mode: structured` or `mode: strided`."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_torch_compile(cls, data):
"""torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles
and CheckpointErrors. The flex_attention kernel compiles itself internally —
whole-model torch.compile is not needed and actively harmful."""
if (
data.get("rl") == "ebft"
and data.get("torch_compile") is True
and data.get("ebft", {}).get("mode") == "strided"
):
if data.get("gradient_checkpointing"):
raise ValueError(
"EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` "
"causes CheckpointError (BlockMask metadata mismatch during recomputation). "
"Remove `torch_compile` — the flex_attention kernel compiles itself internally."
)
LOG.warning(
"EBFT strided mode: `torch_compile: true` causes dynamo recompiles from "
"variable sequence lengths across steps. Consider removing it — "
"flex_attention compiles itself internally."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_gradient_checkpointing_reentrant(cls, data):
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("flex_attention")
and data.get("gradient_checkpointing")
):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
if not gc_kwargs.get("use_reentrant"):
LOG.warning(
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
)
if data.get("gradient_checkpointing_kwargs") is None:
data["gradient_checkpointing_kwargs"] = {}
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
return data
@model_validator(mode="before")
@classmethod
def check_ebft_activation_offloading(cls, data):
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
which conflicts with flex_attention's use_reentrant requirement."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("activation_offloading") is True
and data.get("flex_attention")
):
raise ValueError(
"EBFT strided mode: `activation_offloading: true` is incompatible with "
"`flex_attention: true`. Activation offloading replaces gradient checkpointing "
"with FSDP-style wrapping that conflicts with flex_attention's reentrant "
"checkpoint requirement. Remove `activation_offloading` — the strided trainer "
"uses micro-batched forward passes for memory efficiency instead."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_strided_sequence_len(cls, data):
"""Warn if sequence_len is too large for single-GPU strided EBFT."""
if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided":
return data
ebft = data.get("ebft", {})
seq_len = data.get("sequence_len", 512)
n_samples = ebft.get("n_samples_per_prompt", 4)
gen_len = ebft.get("generate_max_len", 8)
stride = ebft.get("stride", 8)
ctx_len = ebft.get("context_length", 8)
max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
full_seq = seq_len + max_blocks * gen_len
# Rough estimate: 8.7 GB per sample at S=3900 for 1B model
if full_seq * n_samples > 20000:
LOG.warning(
f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = "
f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM "
f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride."
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_strided_dataset_split(cls, data):
"""Warn about the common `train_on_split` mistake (silently ignored by schema)."""
datasets = data.get("datasets", [])
for ds in datasets or []:
if isinstance(ds, dict) and ds.get("train_on_split"):
LOG.warning(
f"Dataset has `train_on_split: {ds['train_on_split']}` — this field "
f"is not recognized and will be silently ignored. "
f"Use `split: {ds['train_on_split']}` instead."
)
return data
class GRPOVllmValidationMixin:
"""Validation mixin for vllm when using GRPO."""
@@ -1507,6 +1625,7 @@ class ValidationMixin(
PretrainingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
EBFTValidationMixin,
GRPOVllmValidationMixin,
):
"""Full validation mixin for Axolotl configuration."""

View File

@@ -57,6 +57,13 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Reasoning parser for VLLM"},
)
enforce_eager: bool | None = Field(
default=None,
json_schema_extra={
"description": "Disable CUDA graph capture in vLLM. Required for models with "
"causal_conv1d (e.g., Qwen3.5 hybrid linear attention)."
},
)
serve_module: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -0,0 +1,94 @@
"""Serialize / deserialize tensors for HTTP and IPC weight sync.
NumPy doesn't support bfloat16, so bf16 tensors are cast to fp16 on the wire
and reconstructed at the destination. All encode/decode helpers live here so
the logic isn't duplicated across trl_vllm.py, vllm_serve_lora.py, and
vllm_worker_ext.py.
"""
import base64
import torch
def encode_for_http(name: str, weight: torch.Tensor) -> dict:
"""Encode a named parameter for JSON transport over HTTP.
Returns a dict with keys: name, dtype (original), shape, data (base64).
bf16 tensors are sent as fp16 bytes; the original dtype is preserved in
the ``dtype`` field so the receiver can cast back.
"""
w_cpu = weight.contiguous().cpu()
orig_dtype = str(weight.dtype)
if w_cpu.dtype == torch.bfloat16:
w_cpu = w_cpu.half()
raw = w_cpu.numpy().tobytes()
return {
"name": name,
"dtype": orig_dtype,
"shape": list(weight.shape),
"data": base64.b64encode(raw).decode("ascii"),
}
def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]:
"""Decode an HTTP-encoded weight entry back to a named tensor.
Infers wire dtype from byte count (bf16 arrives as fp16) and casts to the
original dtype stored in ``entry["dtype"]``.
"""
target_dtype = getattr(torch, entry["dtype"].split(".")[-1])
shape = tuple(entry["shape"])
raw = base64.b64decode(entry["data"])
n_elements = 1
for s in shape:
n_elements *= s
wire_bytes_per_elem = len(raw) // max(n_elements, 1)
if wire_bytes_per_elem == 2:
wire_dtype = torch.float16
elif wire_bytes_per_elem == 4:
wire_dtype = torch.float32
else:
wire_dtype = target_dtype
weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape)
if wire_dtype != target_dtype:
weight = weight.to(target_dtype)
return entry["name"], weight
def encode_for_ipc(name: str, weight: torch.Tensor) -> dict:
"""Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64).
Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype
(original), shape. bf16 tensors are serialized as fp16.
"""
w = weight.contiguous()
target_dtype = str(w.dtype).split(".")[-1]
if w.dtype == torch.bfloat16:
w = w.half()
wire_dtype = str(w.dtype).split(".")[-1]
return {
"name": name,
"data": w.numpy().tobytes(),
"dtype": wire_dtype,
"target_dtype": target_dtype,
"shape": list(weight.shape),
}
def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]:
"""Decode an IPC-encoded weight entry back to a named tensor.
Handles optional ``target_dtype`` for backward compatibility with older
serve code that may not include it.
"""
wire_dtype = getattr(torch, entry["dtype"])
weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape(
entry["shape"]
)
target_dtype = entry.get("target_dtype")
if target_dtype and target_dtype != entry["dtype"]:
weight = weight.to(getattr(torch, target_dtype))
return entry["name"], weight

294
tests/test_ebft_kernels.py Normal file
View File

@@ -0,0 +1,294 @@
"""Correctness tests for fused EBFT Triton kernels."""
import pytest
import torch
import torch.nn.functional as F
from axolotl.core.trainers.ebft.kernels import (
fused_cosine_similarity,
fused_diversity_penalty,
fused_log_softmax_gather,
fused_reinforce_loss,
)
# Skip all tests if CUDA not available
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
DEVICE = "cuda"
# ---------------------------------------------------------------------------
# 1. fused_log_softmax_gather
# ---------------------------------------------------------------------------
class TestFusedLogSoftmaxGather:
def _reference(self, logits, labels):
"""PyTorch reference: log_softmax + gather."""
lp = F.log_softmax(logits.float(), dim=-1)
return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
def test_basic_correctness(self):
B, S, V = 2, 16, 1024
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_large_vocab(self):
"""Test with realistic vocab size (128K)."""
B, S, V = 1, 8, 128256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_fp32_input(self):
B, S, V = 2, 8, 512
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
def test_output_is_negative(self):
"""log_softmax values should always be <= 0."""
B, S, V = 4, 32, 2048
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
out = fused_log_softmax_gather(logits, labels)
assert (out <= 1e-5).all(), "log_softmax values should be <= 0"
def test_extreme_logits(self):
"""Test numerical stability with very large/small logits."""
B, S, V = 1, 4, 256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
logits[:, 0, :] = 1000.0 # very large
logits[:, 1, :] = -1000.0 # very small
logits[:, 2, 0] = 1000.0 # one hot-ish
labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
assert torch.isfinite(out).all(), "Should handle extreme values"
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with pre-flattened (N, V) input."""
N, V = 64, 4096
logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (N,), device=DEVICE)
ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 2. fused_reinforce_loss
# ---------------------------------------------------------------------------
class TestFusedReinforceLoss:
def _reference(self, logps, advantages, mask):
"""PyTorch reference implementation."""
loss_per_token = -logps * advantages
return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)
def test_basic_correctness(self):
N = 1024
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with (B, S) shaped inputs."""
B, S = 4, 256
logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_all_masked(self):
"""All-zero mask should return 0."""
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
out = fused_reinforce_loss(logps, advs, mask)
assert out.item() == 0.0
def test_all_unmasked(self):
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic size (4 * 3000 tokens)."""
N = 12000
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 3. fused_cosine_similarity
# ---------------------------------------------------------------------------
class TestFusedCosineSimilarity:
def test_basic_correctness(self):
N, D = 64, 256
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_batched(self):
"""Test with (B, N, NB, D) shaped input."""
B, N, NB, D = 2, 4, 16, 512
a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_identical_vectors(self):
"""Identical vectors should give similarity = 1."""
N, D = 16, 128
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, a)
torch.testing.assert_close(
out,
torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_orthogonal_vectors(self):
"""Orthogonal vectors should give similarity = 0."""
D = 128
a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
a[0, 0] = 1.0
b[0, 1] = 1.0
out = fused_cosine_similarity(a, b)
assert abs(out.item()) < 1e-5
def test_opposite_vectors(self):
"""Opposite vectors should give similarity = -1."""
N, D = 8, 64
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, -a)
torch.testing.assert_close(
out,
-torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_large_dimension(self):
"""Test with large feature dimension (multi-layer concatenated features)."""
N, D = 32, 4608 # 3 layers * 1536 hidden
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
# ---------------------------------------------------------------------------
# 4. fused_diversity_penalty
# ---------------------------------------------------------------------------
class TestFusedDiversityPenalty:
def _reference(self, embeddings):
"""PyTorch reference: bmm + mask diagonal + mean."""
B, N, D = embeddings.shape
sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
return sims.sum(dim=-1) / (N - 1)
def test_basic_correctness(self):
B, N, D = 4, 4, 256
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_two_samples(self):
"""With n=2, diversity = dot(a, b) for each."""
B, D = 3, 128
emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_identical_samples(self):
"""All identical samples should give max diversity."""
B, N, D = 2, 4, 64
vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
emb = vec.expand(B, N, D).contiguous()
out = fused_diversity_penalty(emb)
# Should be ||vec||^2 for each (self-excluded mean of identical dot products)
expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic EBFT dimensions."""
B, N, D = 1, 4, 4608 # multi-layer features
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)
def test_single_sample_returns_zeros(self):
"""N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
B, D = 3, 128
emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
out = fused_diversity_penalty(emb)
assert (out == 0).all(), "N=1 diversity should be exactly zero"

View File

@@ -0,0 +1,363 @@
"""Tests for the EBFT strided structured dataset transform and data loading."""
import pytest
from datasets import Dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
@pytest.fixture
def tokenizer():
"""Create a simple word-level tokenizer — no network access needed."""
# Build a tiny vocab covering common test words
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
words = (
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
"x write code print test some string metadata noise ok good python "
"sampling abc 123 def solve return this that"
).split()
for w in words:
if w not in vocab:
vocab[w] = len(vocab)
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
backend.pre_tokenizer = pre_tokenizers.Whitespace()
tok = PreTrainedTokenizerFast(
tokenizer_object=backend,
bos_token="[BOS]",
eos_token="[EOS]",
pad_token="[PAD]",
unk_token="[UNK]",
)
return tok
@pytest.fixture
def cfg():
return DictDefault({"sequence_len": 64})
@pytest.fixture
def transform_fn_and_kwargs(cfg):
result = load_ebft("ebft_strided_structured.transform", cfg)
assert result is not None, "Failed to load ebft_strided_structured transform"
transform_fn, map_kwargs = result
return transform_fn, map_kwargs
class TestEBFTStridedStructuredTransform:
"""Tests for the dataset transform function itself."""
def test_transform_loads(self, transform_fn_and_kwargs):
transform_fn, map_kwargs = transform_fn_and_kwargs
assert callable(transform_fn)
assert "remove_columns" in map_kwargs
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
"""Transform should signal removal of all original columns."""
_, map_kwargs = transform_fn_and_kwargs
assert map_kwargs["remove_columns"] == "__all__"
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
"""Prompt tokens get labels=-100, completion tokens get real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
result = transform_fn(example, tokenizer=tokenizer)
assert "input_ids" in result
assert "labels" in result
assert "attention_mask" in result
assert "prompt_length" in result
prompt_length = result["prompt_length"]
labels = result["labels"]
seq_len = len(result["input_ids"])
assert seq_len == 64, "Should be padded to sequence_len"
assert len(labels) == seq_len
assert prompt_length > 0
# Prompt tokens should be masked
for i in range(prompt_length):
assert labels[i] == -100, f"Prompt token at {i} should be -100"
# At least one completion token should have a real label
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
"""prompt_length should be the exact boundary between -100 and real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hello world", "output": "goodbye world"}
result = transform_fn(example, tokenizer=tokenizer)
prompt_length = result["prompt_length"]
labels = result["labels"]
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
assert labels[prompt_length] != -100, (
"First completion token should not be masked"
)
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
"""Padding tokens should have labels=-100 and attention_mask=0."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hi", "output": "bye"}
result = transform_fn(example, tokenizer=tokenizer)
attention_mask = result["attention_mask"]
labels = result["labels"]
real_len = sum(attention_mask)
assert real_len < 64, "Short example should have padding"
for i in range(real_len, 64):
assert attention_mask[i] == 0, (
f"Pad position {i} should have attention_mask=0"
)
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
"""Long completions should be truncated to fit sequence_len."""
transform_fn, _ = transform_fn_and_kwargs
example = {
"input": "short prompt",
"output": "x " * 500,
}
result = transform_fn(example, tokenizer=tokenizer)
assert len(result["input_ids"]) == 64
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
"""Transform should handle different field name conventions."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn(
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
result = transform_fn(
{"question": "what", "answer": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
"""Without tokenizer, should return a prompt string."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn({"input": "hello", "output": "world"})
assert "prompt" in result
assert result["prompt"] == "hello"
class TestEBFTColumnRemoval:
"""Tests for the __all__ column removal logic in the RL data path."""
def _filter_remove_columns(self, map_kwargs, dataset):
"""Reproduce the filtering logic from rl.py _load_split."""
if "remove_columns" in map_kwargs:
ds_columns = dataset.column_names
if map_kwargs["remove_columns"] == "__all__":
map_kwargs["remove_columns"] = list(ds_columns)
else:
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
return map_kwargs
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""After mapping, only tokenized columns should remain."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs) # copy
ds = Dataset.from_list(
[
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert "input" in map_kwargs["remove_columns"]
assert "output" in map_kwargs["remove_columns"]
assert "extra_field" in map_kwargs["remove_columns"]
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
assert "input_ids" in mapped.column_names
assert "labels" in mapped.column_names
assert "prompt_length" in mapped.column_names
assert "input" not in mapped.column_names
assert "output" not in mapped.column_names
assert "extra_field" not in mapped.column_names
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""Datasets with many extra metadata columns should all be cleaned up."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{
"input": "write hello world",
"output": "print hello",
"id": "abc 123",
"domain": "python",
"generation_algorithm": "sampling",
"llm_judgement": "good",
"unit_tests": "test",
"tests_execution_status": "ok",
"average_test_score": 0.95,
},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert len(map_kwargs["remove_columns"]) == 9
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
assert set(mapped.column_names) == expected_columns
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
"""No string-typed columns should remain (would crash the DataLoader)."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{"input": "test", "output": "test", "notes": "some string metadata"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
for col in mapped.column_names:
feat = mapped.features[col]
assert str(feat) != "string", (
f"Column '{col}' is still a string — would crash DataLoader"
)
def test_filter_preserves_explicit_list(self):
"""When remove_columns is an explicit list, only existing columns are kept."""
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
ds_columns = ds.column_names
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
assert map_kwargs["remove_columns"] == ["a", "b"]
assert "missing_col" not in map_kwargs["remove_columns"]
class TestMultiTurnSeparators:
"""Verify multi-turn transforms and trainer-side GT reconstruction."""
def test_multiturn_transform_splits_turns(self):
"""Transform should store first turn as GT and remaining turns separately."""
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"sequence_len": 512})
fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
out = fn(
{
"messages": [
{"role": "user", "content": "Q1"},
{"role": "assistant", "content": "A1"},
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
}
)
# ground_truth is only the first assistant turn
assert out["ground_truth"] == "A1"
# remaining_turns carries the rest for trainer-side reconstruction
assert out["remaining_turns"] == [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
def test_multiturn_gt_reconstruction_via_chat_template(self):
"""Trainer-side GT reconstruction should insert role markers between turns.
This tests the logic from trainer.py:284-299 that reconstructs multi-turn
GT using apply_chat_template, ensuring assistant turns are separated by
role markers rather than concatenated as raw text.
"""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
# Simulate the transform output
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
remaining_turns = [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
# --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
gt_conv.extend(remaining_turns)
full_gt_text = tokenizer.apply_chat_template(
gt_conv, tokenize=False, add_generation_prompt=False
)
# The full GT text should contain both assistant turns with role markers
assert "A1" in full_gt_text
assert "A2" in full_gt_text
# Raw concatenation "A1A2" should NOT appear — role markers separate them
assert "A1A2" not in full_gt_text, (
"GT reconstruction should have role markers between turns, not raw concatenation"
)
# The user turn Q2 should appear between A1 and A2
a1_pos = full_gt_text.index("A1")
a2_pos = full_gt_text.index("A2")
q2_pos = full_gt_text.index("Q2")
assert a1_pos < q2_pos < a2_pos, (
"Turn order should be A1 -> Q2 -> A2 in rendered GT"
)
# The GT should start with the prompt
assert full_gt_text.startswith(prompt_text), (
"Full GT should start with the rendered prompt"
)
def test_multiturn_gt_reconstruction_fallback_single_turn(self):
"""Single-turn prompts in a multi-turn dataset should use raw concatenation."""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
# remaining_turns would be [] for single-turn prompts
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
# With empty remaining_turns, trainer falls through to raw concat
# (trainer.py:302: gt_texts.append(prompt_text + gt))
gt_text = prompt_text + gt
assert gt_text.endswith("A1")
assert prompt_text in gt_text

View File

@@ -0,0 +1,158 @@
"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
"""
import pytest
import torch
from axolotl.utils.weight_serde import (
decode_from_http,
decode_from_ipc,
encode_for_http,
encode_for_ipc,
)
# ---------------------------------------------------------------------------
# Stage 1: trainer → serve endpoint (HTTP with base64)
# ---------------------------------------------------------------------------
class TestHttpEncodeRoundTrip:
"""Test encode_for_http / decode_from_http."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_round_trip_dtype(self, dtype):
original = torch.randn(32, 64, dtype=dtype)
entry = encode_for_http("layer.weight", original)
name, decoded = decode_from_http(entry)
assert name == "layer.weight"
assert decoded.dtype == dtype
assert decoded.shape == original.shape
if dtype == torch.bfloat16:
# bf16→fp16→bf16 loses some precision
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
def test_bfloat16_wire_format_is_fp16(self):
"""bf16 tensors should be sent as fp16 on the wire."""
import base64
original = torch.randn(8, 16, dtype=torch.bfloat16)
entry = encode_for_http("w", original)
raw = base64.b64decode(entry["data"])
# 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
assert len(raw) == 8 * 16 * 2
# dtype field should preserve original dtype for reconstruction
assert entry["dtype"] == "torch.bfloat16"
def test_multidimensional_shapes(self):
for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
original = torch.randn(*shape, dtype=torch.bfloat16)
entry = encode_for_http("w", original)
_, decoded = decode_from_http(entry)
assert decoded.shape == original.shape
assert decoded.dtype == torch.bfloat16
# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------
class TestIpcEncodeRoundTrip:
"""Test encode_for_ipc / decode_from_ipc."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_round_trip_dtype(self, dtype):
original = torch.randn(32, 64, dtype=dtype)
entry = encode_for_ipc("layer.weight", original)
name, decoded = decode_from_ipc(entry)
assert name == "layer.weight"
assert decoded.dtype == dtype
assert decoded.shape == original.shape
if dtype == torch.bfloat16:
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
def test_bfloat16_ipc_wire_is_fp16(self):
"""bf16 tensors should be serialized as fp16 bytes in IPC."""
original = torch.randn(4, 8, dtype=torch.bfloat16)
entry = encode_for_ipc("w", original)
assert entry["dtype"] == "float16"
assert entry["target_dtype"] == "bfloat16"
assert len(entry["data"]) == 4 * 8 * 2 # fp16 bytes
def test_fp32_has_no_target_dtype_mismatch(self):
original = torch.randn(4, 8, dtype=torch.float32)
entry = encode_for_ipc("w", original)
assert entry["dtype"] == "float32"
assert entry["target_dtype"] == "float32"
def test_worker_handles_missing_target_dtype(self):
"""Backward compat: older serve code may not send target_dtype."""
entry = {
"name": "w",
"data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
"dtype": "float32",
"shape": [4, 8],
# no target_dtype key
}
name, decoded = decode_from_ipc(entry)
assert decoded.dtype == torch.float32
assert decoded.shape == (4, 8)
# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------
class TestFullPipelineRoundTrip:
"""End-to-end: trainer → serve → worker."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_three_stage_round_trip(self, dtype):
"""Tensor survives trainer→serve→worker with correct dtype and values."""
original = torch.randn(16, 32, dtype=dtype)
# Stage 1: trainer encodes for HTTP
http_entry = encode_for_http("model.layers.0.weight", original)
# Stage 2: serve decodes HTTP, re-encodes for IPC
name, at_serve = decode_from_http(http_entry)
ipc_entry = encode_for_ipc(name, at_serve)
# Stage 3: worker decodes IPC
_, at_worker = decode_from_ipc(ipc_entry)
assert at_worker.dtype == dtype
assert at_worker.shape == original.shape
if dtype == torch.bfloat16:
# Two bf16→fp16→bf16 hops compound precision loss slightly
torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
else:
torch.testing.assert_close(at_worker, original, atol=0, rtol=0)
def test_bfloat16_precision_loss_is_bounded(self):
"""bf16→fp16→bf16 round-trip error should be small."""
original = torch.randn(256, 256, dtype=torch.bfloat16)
http_entry = encode_for_http("w", original)
_, at_serve = decode_from_http(http_entry)
ipc_entry = encode_for_ipc("w", at_serve)
_, at_worker = decode_from_ipc(ipc_entry)
max_err = (at_worker.float() - original.float()).abs().max().item()
# bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
assert max_err < 0.05, f"Max error {max_err} exceeds bound"
def test_bfloat16_numpy_would_crash_without_fix(self):
"""Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
t = torch.randn(4, 4, dtype=torch.bfloat16)
with pytest.raises((RuntimeError, TypeError)):
t.numpy()