diff --git a/docker/Dockerfile-cloud-uv b/docker/Dockerfile-cloud-uv
index a53dd6135..2facb6fa7 100644
--- a/docker/Dockerfile-cloud-uv
+++ b/docker/Dockerfile-cloud-uv
@@ -22,6 +22,7 @@ RUN apt update && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
+ printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
diff --git a/examples/ebft/README.md b/examples/ebft/README.md
new file mode 100644
index 000000000..533e13652
--- /dev/null
+++ b/examples/ebft/README.md
@@ -0,0 +1,214 @@
+# Energy-Based Fine-Tuning (EBFT)
+
+This directory integrates Energy-Based Fine-Tuning (EBFT), introduced in ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026), into axolotl.
+
+## Overview
+
+EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
+
+**Key advantages over SFT:**
+- Operates on model rollouts (not teacher forcing), reducing distribution shift
+- Provides dense sequence-level supervision without a task-specific verifier
+- Improves both downstream accuracy and validation cross-entropy simultaneously
+
+**Key advantages over RLVR:**
+- No reward model or verifier required — works on any (prompt, completion) data
+- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
+- Maintains distributional calibration (low feature-matching loss)
+
+## Two Modes
+
+EBFT supports two modes depending on your data format:
+
+### Structured Mode (`mode: structured`, default)
+For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).
+- Extends GRPOTrainer — uses vLLM for fast rollout generation
+- RLOO advantages and clipped policy gradient from GRPO
+- Feature-matching rewards replace external reward functions
+
+### Strided Mode (`mode: strided`)
+For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).
+- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
+- No vLLM needed — generation uses custom strided attention masks
+- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
+- Compatible with gradient checkpointing via automatic dtype normalization
+- This is the core EBFT algorithm from the paper (Section F)
+
+### Common to both modes:
+- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
+- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
+- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
+- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
+- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
+- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)
+
+## Quick Start
+
+### Structured Mode (QA data + vLLM)
+
+```bash
+# 1. Start vLLM server
+python -m trl.scripts.vllm_serve \
+ --model meta-llama/Llama-3.2-1B \
+ --host 0.0.0.0 --port 8000 \
+ --gpu-memory-utilization 0.3
+
+# 2. Train
+axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
+```
+
+### Strided Mode (unstructured text)
+
+```bash
+# No vLLM needed — strided generation is built-in
+axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
+```
+
+## Configuration
+
+### Common EBFT Settings
+
+```yaml
+rl: ebft
+
+ebft:
+ # Feature network: which layers to extract hidden states from
+ # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
+ feature_layers: [0.25, 0.5, 0.75]
+
+ # How to pool per-token hidden states into sequence embeddings
+ # Options: "last_token" (recommended), "mean_pooling", "concat"
+ embed_method: last_token
+
+ # SVD whitening — strongly recommended (paper shows largest degradation without it)
+ use_whitening: true
+
+ # Reward = alignment_coef * alignment - diversity_coef * diversity
+ # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
+ # diversity uses raw dot product — both are bounded after whitening.
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+
+ # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
+ # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
+ ce_coef: 0.0
+```
+
+### Strided Mode Settings
+
+```yaml
+ebft:
+ mode: strided
+ stride: 8 # tokens between anchor points (paper default: 8)
+ context_length: 8 # context window per block (paper default: 8)
+ generate_max_len: 8 # tokens generated per block (paper default: 8)
+ n_samples_per_prompt: 4 # independent rollouts per document (>= 2 for RLOO)
+ temperature: 0.6
+ rl_coef: 1.0 # RL loss weight
+ advantage_estimator: rloo # rloo (recommended), group_norm, or reinforce
+```
+
+### Structured Mode Settings (via TRL)
+
+```yaml
+trl:
+ num_generations: 4 # samples per prompt
+ max_completion_length: 256 # max tokens to generate
+ temperature: 1.0
+ use_vllm: true
+ scale_rewards: true
+ loss_type: grpo
+ epsilon: 0.2
+```
+
+### Dataset Format
+
+**Structured mode** — QA data with prompt + ground-truth completion:
+```yaml
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_opencode.transform
+```
+Transform returns: `{"prompt": ..., "ground_truth": ...}`
+
+**Strided mode** — raw text tokenized to fixed length:
+```yaml
+datasets:
+  - path: sjelassi/swallow_code_20m
+    type: ebft_pretrain.transform
+```
+Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`
+
+## How It Works
+
+### Structured Mode
+1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
+2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
+3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
+4. **RLOO advantages**: subtract leave-one-out group mean
+5. **Policy gradient**: clipped PPO-style loss
+
+### Strided Mode
+1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document
+2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
+3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
+4. **Per-block rewards**:
+ - **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
+ - **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
+ - **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
+5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
+6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
+
+### Tracked Metrics
+
+| Metric | Description |
+|--------|-------------|
+| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
+| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
+| `ebft/mean_reward` | alignment - diversity (should trend upward) |
+| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) |
+| `ebft/rl_loss` | REINFORCE policy gradient loss |
+| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
+| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
+
+## Tips and Recommendations
+
+### Reward coefficients
+- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
+- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
+- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
+- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching.
+
+### Feature extraction
+- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
+- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).
+
+### Performance
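+- **`torch_compile: true`**: Can provide additional speedup for strided mode via graph compilation. Note that `llama-1b-ebft-strided-structured.yaml` warns that full-model compile conflicts with gradient checkpointing combined with `flex_attention`; in that setup, leave it unset and rely on the self-compiling `flex_attention` kernel.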
+- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.
+
+### Memory
+- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
+- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
+- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
+- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.
+
+## Example Configs
+
+| Config | Mode | Model | Description |
+|--------|------|-------|-------------|
+| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
+| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
+| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
+| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |
+
+## Citation
+
+```bibtex
+@article{jelassi2026matching,
+ title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
+ author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
+ journal={arXiv preprint arXiv:2603.12248},
+ year={2026}
+}
+```
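The per-block reward and RLOO steps described in the README's "How It Works" section can be sketched in plain Python. This is a minimal illustration, not the trainer's implementation: the helper names (`feature_matching_rewards`, `rloo_advantages`) are hypothetical, short lists stand in for whitened feature vectors, and excluding each sample's dot product with itself from the diversity term is an assumption.

```python
import math


def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def cosine(u, v):
    """Cosine similarity; assumes neither vector is all zeros."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


def feature_matching_rewards(gen_embs, gt_emb, alignment_coef=1.0, diversity_coef=1.0):
    """Reward per sample: alignment_coef * alignment - diversity_coef * diversity."""
    n = len(gen_embs)
    rewards = []
    for i, emb in enumerate(gen_embs):
        # Alignment: 2 * cosine similarity with the ground-truth embedding.
        alignment = 2.0 * cosine(emb, gt_emb)
        # Diversity: 2 * mean raw dot product with the other samples
        # (assumption: the sample's dot product with itself is excluded).
        others = [dot(emb, gen_embs[j]) for j in range(n) if j != i]
        diversity = 2.0 * sum(others) / len(others)
        rewards.append(alignment_coef * alignment - diversity_coef * diversity)
    return rewards


def rloo_advantages(rewards):
    """Leave-one-out baseline: subtract the mean reward of the other samples."""
    n = len(rewards)
    total = sum(rewards)
    return [r - (total - r) / (n - 1) for r in rewards]


# Two rollouts for one block: the first matches the ground truth exactly,
# the second is orthogonal to it.
rewards = feature_matching_rewards([[1.0, 0.0], [0.0, 1.0]], gt_emb=[1.0, 0.0])
advantages = rloo_advantages(rewards)
```

Both helpers need at least two samples per group, which is consistent with the README's requirement that `n_samples_per_prompt` be >= 2. Leave-one-out advantages within a group always sum to zero.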
diff --git a/examples/ebft/ebft_opencode.py b/examples/ebft/ebft_opencode.py
new file mode 100644
index 000000000..677949ba5
--- /dev/null
+++ b/examples/ebft/ebft_opencode.py
@@ -0,0 +1,28 @@
+"""
+Dataset transform for nvidia/OpenCodeInstruct with EBFT.
+
+Maps the dataset's `input` (prompt) and `output` (code solution) fields
+to the format expected by the EBFT trainer.
+"""
+
+
+def transform(cfg, *args, **kwargs):
+    def transform_fn(example, tokenizer=None):
+        return {
+            "prompt": [
+                {"role": "user", "content": example["input"]},
+            ],
+            "ground_truth": example["output"],
+        }
+
+    return transform_fn, {
+        "remove_columns": [
+            "id",
+            "domain",
+            "generation_algorithm",
+            "llm_judgement",
+            "unit_tests",
+            "tests_execution_status",
+            "average_test_score",
+        ]
+    }
diff --git a/examples/ebft/ebft_pretrain.py b/examples/ebft/ebft_pretrain.py
new file mode 100644
index 000000000..27a1e54b9
--- /dev/null
+++ b/examples/ebft/ebft_pretrain.py
@@ -0,0 +1,31 @@
+"""
+Dataset transform for unstructured text data with strided EBFT.
+
+Tokenizes raw text into fixed-length input_ids for the strided trainer.
+Sequences are padded to sequence_len for uniform batching.
+"""
+
+
+def transform(cfg, *args, **kwargs):
+    seq_len = cfg.sequence_len
+
+    def transform_fn(example, tokenizer=None):
+        text = example.get("question", example.get("text", ""))
+        if tokenizer is None:
+            return {"prompt": text}
+
+        encoded = tokenizer(
+            text,
+            truncation=True,
+            max_length=seq_len,
+            padding="max_length",
+            add_special_tokens=True,
+            return_tensors=None,
+        )
+        return {
+            "input_ids": encoded["input_ids"],
+            "attention_mask": encoded["attention_mask"],
+            "labels": list(encoded["input_ids"]),
+        }
+
+    return transform_fn, {"remove_columns": ["question", "answer"]}
diff --git a/examples/ebft/ebft_strided_structured.py b/examples/ebft/ebft_strided_structured.py
new file mode 100644
index 000000000..48743931a
--- /dev/null
+++ b/examples/ebft/ebft_strided_structured.py
@@ -0,0 +1,80 @@
+"""
+Dataset transform for structured (prompt, completion) data with strided EBFT.
+
+Tokenizes prompt and completion separately, concatenates into a single
+input_ids sequence, and marks prompt tokens with labels=-100 so the
+strided trainer knows where to place anchors (completion span only).
+
+Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
+"""
+
+
+def transform(cfg, *args, **kwargs):
+    seq_len = cfg.sequence_len
+
+    def transform_fn(example, tokenizer=None):
+        # Extract prompt and completion from the example
+        prompt_text = example.get(
+            "input", example.get("prompt", example.get("question", ""))
+        )
+        completion_text = example.get(
+            "output", example.get("completion", example.get("answer", ""))
+        )
+
+        if tokenizer is None:
+            return {"prompt": prompt_text}
+
+        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+
+        # Tokenize prompt and completion separately
+        prompt_enc = tokenizer(
+            prompt_text,
+            truncation=False,
+            add_special_tokens=True,
+            return_tensors=None,
+        )
+        completion_enc = tokenizer(
+            completion_text,
+            truncation=False,
+            add_special_tokens=False,
+            return_tensors=None,
+        )
+
+        prompt_ids = prompt_enc["input_ids"]
+        completion_ids = completion_enc["input_ids"]
+
+        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
+        total_len = len(prompt_ids) + len(completion_ids)
+        if total_len > seq_len:
+            # Truncate completion first, then prompt if needed
+            max_completion = seq_len - len(prompt_ids)
+            if max_completion < 1:
+                # Prompt alone exceeds seq_len: truncate prompt, keep at least 1 completion token
+                prompt_ids = prompt_ids[: seq_len - 1]
+                completion_ids = completion_ids[:1]
+            else:
+                completion_ids = completion_ids[:max_completion]
+
+        input_ids = prompt_ids + completion_ids
+        prompt_length = len(prompt_ids)
+
+        # Labels: -100 for prompt tokens, input_ids for completion tokens
+        labels = [-100] * prompt_length + completion_ids
+
+        # Pad to seq_len
+        pad_len = seq_len - len(input_ids)
+        attention_mask = [1] * len(input_ids) + [0] * pad_len
+        labels = labels + [-100] * pad_len
+        input_ids = input_ids + [pad_id] * pad_len
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "prompt_length": prompt_length,
+        }
+
+    # Signal to remove all original columns (filtered to existing ones at map time)
+    return transform_fn, {
+        "remove_columns": "__all__",
+    }
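The truncation, masking, and padding logic in `ebft_strided_structured.py` is easier to follow on pre-tokenized IDs. The sketch below mirrors those steps without a tokenizer; `pack_example` is a hypothetical helper written for illustration and the token IDs are made up.

```python
def pack_example(prompt_ids, completion_ids, seq_len, pad_id):
    """Mirror the truncate -> mask -> pad steps of transform_fn on raw ID lists."""
    if len(prompt_ids) + len(completion_ids) > seq_len:
        max_completion = seq_len - len(prompt_ids)
        if max_completion < 1:
            # Prompt alone fills the window: keep seq_len - 1 prompt tokens
            # and a single completion token.
            prompt_ids = prompt_ids[: seq_len - 1]
            completion_ids = completion_ids[:1]
        else:
            completion_ids = completion_ids[:max_completion]

    input_ids = prompt_ids + completion_ids
    # Prompt positions are masked with -100 so only the completion is supervised.
    labels = [-100] * len(prompt_ids) + completion_ids
    pad_len = seq_len - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * pad_len,
        "attention_mask": [1] * len(input_ids) + [0] * pad_len,
        "labels": labels + [-100] * pad_len,
        "prompt_length": len(prompt_ids),
    }


# Case 1: everything fits, so the tail is padded.
short = pack_example([1, 2, 3], [4, 5], seq_len=8, pad_id=0)
# Case 2: the prompt alone exceeds seq_len, so it is cut to leave one completion token.
long_prompt = pack_example([1, 2, 3, 4, 5, 6], [7, 8], seq_len=4, pad_id=0)
```

In the second case only one completion token survives, so such rows give the strided trainer almost no completion span in which to place anchors.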
diff --git a/examples/ebft/llama-1b-ebft-opencode-novllm.yaml b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml
new file mode 100644
index 000000000..0891033f0
--- /dev/null
+++ b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml
@@ -0,0 +1,64 @@
+# EBFT validation config — no vLLM, uses HF generate for simplicity
+# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml
+
+base_model: meta-llama/Llama-3.2-1B
+chat_template: llama3
+rl: ebft
+
+ebft:
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: false
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ ce_coef: 0.0
+
+trl:
+ num_generations: 4
+ max_completion_length: 128
+ temperature: 1.0
+ use_vllm: false
+ scale_rewards: true
+ loss_type: grpo
+ epsilon: 0.2
+
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_opencode.transform
+    split: train[:1%]
+
+sequence_len: 512
+micro_batch_size: 2
+gradient_accumulation_steps: 2
+num_epochs: 1
+max_steps: 10
+
+learning_rate: 1.0e-5
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 2
+weight_decay: 0.01
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_linear: true
+
+bf16: auto
+flash_attention: true
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-validation
+
+wandb_project: ebft
+wandb_run_id:
+wandb_watch:
+wandb_log_model:
+
+logging_steps: 1
+save_steps: 100
diff --git a/examples/ebft/llama-1b-ebft-opencode.yaml b/examples/ebft/llama-1b-ebft-opencode.yaml
new file mode 100644
index 000000000..d0d1069d8
--- /dev/null
+++ b/examples/ebft/llama-1b-ebft-opencode.yaml
@@ -0,0 +1,81 @@
+# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
+#
+# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
+# https://arxiv.org/abs/2603.12248
+#
+# Prerequisites:
+# 1. Start vLLM server on a separate GPU:
+# CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
+# --model meta-llama/Llama-3.2-1B \
+# --host 0.0.0.0 --port 8000 \
+# --gpu-memory-utilization 0.4 --dtype bfloat16
+#
+# 2. Run training:
+# CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
+
+base_model: meta-llama/Llama-3.2-1B
+chat_template: llama3
+
+# --- Training method ---
+rl: ebft
+
+# --- EBFT configuration ---
+ebft:
+ feature_layers: [0.25, 0.5, 0.75] # extract hidden states at 25%, 50%, 75% depth
+ embed_method: last_token # pool to sequence embedding via last token
+ use_whitening: false # SVD whitening (disable for speed in small runs)
+ alignment_coef: 1.0 # cosine similarity with ground-truth features
+ diversity_coef: 1.0 # pairwise similarity penalty
+ ce_coef: 0.0 # cross-entropy on ground-truth (0 = pure feature matching)
+
+# --- Generation settings (via TRL/GRPO infrastructure) ---
+trl:
+ num_generations: 4 # samples per prompt for RLOO
+ max_completion_length: 256 # max generated tokens
+ temperature: 1.0
+ use_vllm: true
+ scale_rewards: true
+ loss_type: grpo
+ epsilon: 0.2
+
+# --- Dataset ---
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_opencode.transform
+    split: train[:1%]  # first 1% for validation runs
+
+# --- Training hyperparameters ---
+sequence_len: 1024
+micro_batch_size: 2
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 50
+
+learning_rate: 1.0e-5
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 5
+weight_decay: 0.01
+
+# --- LoRA (recommended to reduce memory with frozen feature network) ---
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_linear: true
+
+# --- Hardware ---
+bf16: auto
+flash_attention: true
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-llama-1b-opencode
+
+# --- Logging ---
+use_tensorboard: true
+logging_steps: 1
+save_steps: 25
diff --git a/examples/ebft/llama-1b-ebft-strided-structured.yaml b/examples/ebft/llama-1b-ebft-strided-structured.yaml
new file mode 100644
index 000000000..8ba63b64b
--- /dev/null
+++ b/examples/ebft/llama-1b-ebft-strided-structured.yaml
@@ -0,0 +1,65 @@
+# EBFT Strided Structured Mode: For structured (prompt, completion) data
+# Uses strided block-parallel generation on completion spans — no vLLM needed.
+#
+# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml
+
+base_model: meta-llama/Llama-3.2-1B
+rl: ebft
+
+ebft:
+ mode: strided # strided block-parallel generation
+ stride: 8 # tokens between anchor points
+ context_length: 8 # context window per block
+ generate_max_len: 8 # tokens to generate per block
+ n_samples_per_prompt: 4 # rollouts per document
+ temperature: 0.6
+ top_p: 1.0
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: true
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ rl_coef: 1.0
+ ce_coef: 0.03 # small CE weight for structured data
+ advantage_estimator: rloo
+ min_completion_prefix: 8 # skip anchors too close to prompt boundary
+
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_strided_structured.transform
+    split: train[:1%]
+
+sequence_len: 2048
+micro_batch_size: 1
+gradient_accumulation_steps: 2
+num_epochs: 1
+# max_steps: 10
+
+learning_rate: 1.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 5
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_linear: true
+
+bf16: auto
+flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
+flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
+ # (full-model compile conflicts with gradient checkpointing + flex_attention)
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: true # required for flex_attention (non-reentrant causes CheckpointError)
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-strided-structured
+
+wandb_project: ebft
+logging_steps: 1
+save_steps: 100
diff --git a/examples/ebft/llama-1b-ebft-strided.yaml b/examples/ebft/llama-1b-ebft-strided.yaml
new file mode 100644
index 000000000..c9519f160
--- /dev/null
+++ b/examples/ebft/llama-1b-ebft-strided.yaml
@@ -0,0 +1,60 @@
+# EBFT Strided Mode: For unstructured text data (raw code, prose)
+# Uses strided block-parallel generation — no vLLM needed.
+#
+# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml
+
+base_model: meta-llama/Llama-3.2-1B
+rl: ebft
+
+ebft:
+ mode: strided # strided block-parallel generation
+ stride: 8 # tokens between anchor points
+ context_length: 8 # context window per block
+ generate_max_len: 8 # tokens to generate per block
+ n_samples_per_prompt: 4 # rollouts per document
+ temperature: 0.6
+ top_p: 1.0
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: true
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ rl_coef: 1.0
+ ce_coef: 0.0
+ advantage_estimator: rloo
+
+datasets:
+  - path: sjelassi/swallow_code_20m
+    type: ebft_pretrain.transform
+    split: train[:100]
+
+sequence_len: 256
+micro_batch_size: 1
+gradient_accumulation_steps: 2
+num_epochs: 1
+max_steps: 5
+
+learning_rate: 1.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 2
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_linear: true
+
+bf16: auto
+flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-strided-validation
+
+wandb_project: ebft
+logging_steps: 1
+save_steps: 100
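The README gives the block count for strided mode as `(seq_len - gen_len - ctx_len) / stride + 1`. As a quick sanity check on the example configs, the sketch below evaluates that formula for the `sequence_len` values they use. It assumes floor division and that a document too short for a single block yields zero blocks; `num_blocks` is a hypothetical helper, not a function from the trainer.

```python
def num_blocks(seq_len, stride, context_length, generate_max_len):
    """Anchor count per document, following the README formula with floor division."""
    usable = seq_len - generate_max_len - context_length
    if usable < 0:
        # Assumption: sequences too short for one block produce no anchors.
        return 0
    return usable // stride + 1


# stride, context_length, and generate_max_len are all 8 in the example configs.
blocks_by_seq_len = {
    seq_len: num_blocks(seq_len, stride=8, context_length=8, generate_max_len=8)
    for seq_len in (256, 1024, 2048)
}
```

Under these assumptions, `sequence_len: 256` gives 31 blocks per document, 1024 gives 127, and 2048 gives 255. Each block is sampled `n_samples_per_prompt` times.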
diff --git a/examples/ebft/llama-3b-ebft-strided-fft.yaml b/examples/ebft/llama-3b-ebft-strided-fft.yaml
new file mode 100644
index 000000000..5695efa40
--- /dev/null
+++ b/examples/ebft/llama-3b-ebft-strided-fft.yaml
@@ -0,0 +1,69 @@
+# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
+# Actor on GPU 0, frozen feature network on GPU 1
+#
+# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml
+
+base_model: meta-llama/Llama-3.2-3B
+rl: ebft
+
+ebft:
+ mode: strided
+ stride: 8
+ context_length: 8
+ generate_max_len: 8
+ n_samples_per_prompt: 4
+ temperature: 0.6
+ top_p: 1.0
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: true
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ rl_coef: 1.0
+ ce_coef: 0.0 # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate
+ advantage_estimator: rloo
+
+datasets:
+  - path: sjelassi/swallow_code_20m
+    type: ebft_pretrain.transform
+    split: train[:5000]
+
+sequence_len: 1024
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 100
+
+learning_rate: 1.0e-5
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 10
+weight_decay: 0.01
+
+adapter: lora
+lora_r: 32
+lora_alpha: 64
+lora_dropout: 0.05
+lora_target_linear: true
+
+bf16: auto
+torch_dtype: bfloat16
+flash_attention: false
+gradient_checkpointing: true
+torch_compile: true
+gradient_checkpointing_kwargs:
+ use_reentrant: true
+ddp: false
+device_map:
+ "": 0
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-llama3b-strided
+
+wandb_project: ebft
+wandb_name: llama3b-strided-lora-100steps
+logging_steps: 1
+save_steps: 50
diff --git a/examples/ebft/llama-8b-ebft-strided-fft.yaml b/examples/ebft/llama-8b-ebft-strided-fft.yaml
new file mode 100644
index 000000000..8cf962849
--- /dev/null
+++ b/examples/ebft/llama-8b-ebft-strided-fft.yaml
@@ -0,0 +1,58 @@
+# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
+# Feature network is CPU-offloaded to fit in single 32GB GPU
+#
+# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml
+
+base_model: meta-llama/Llama-3.1-8B
+rl: ebft
+
+ebft:
+ mode: strided
+ stride: 8
+ context_length: 8
+ generate_max_len: 8
+ n_samples_per_prompt: 4
+ temperature: 0.6
+ top_p: 1.0
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: true
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ rl_coef: 1.0
+ ce_coef: 0.0
+ advantage_estimator: rloo
+
+datasets:
+  - path: sjelassi/swallow_code_20m
+    type: ebft_pretrain.transform
+    split: train[:5000]
+
+sequence_len: 1024
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 100
+
+learning_rate: 1.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 10
+weight_decay: 0.01
+
+bf16: auto
+flash_attention: false # strided EBFT uses flex_attention at runtime
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+
+special_tokens:
+ pad_token: "<|end_of_text|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-llama8b-strided
+
+wandb_project: ebft
+wandb_name: llama8b-strided-fft-100steps
+logging_steps: 1
+save_steps: 50
diff --git a/examples/ebft/qwen35-4b-ebft-structured-async.yaml b/examples/ebft/qwen35-4b-ebft-structured-async.yaml
new file mode 100644
index 000000000..759a31730
--- /dev/null
+++ b/examples/ebft/qwen35-4b-ebft-structured-async.yaml
@@ -0,0 +1,86 @@
+# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
+#
+# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
+# full attention every 4th layer. This tests EBFT compatibility.
+#
+# Prerequisites:
+# 1. Start vLLM on GPU 0:
+# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
+#
+# 2. Run training on GPU 1:
+# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
+# axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml
+
+base_model: Qwen/Qwen3.5-4B
+
+rl: ebft
+
+ebft:
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: false
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ ce_coef: 0.0
+
+trl:
+  num_generations: 4
+  max_completion_length: 256
+  temperature: 0.7
+  use_vllm: true
+  vllm_server_host: 0.0.0.0
+  vllm_server_port: 8000
+  scale_rewards: true
+  loss_type: grpo
+  epsilon: 0.2
+  generation_kwargs:
+    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
+  chat_template_kwargs:
+    enable_thinking: false
+  async_prefetch: true
+  vllm_server_timeout: 300
+
+vllm:
+ gpu_memory_utilization: 0.5
+ max_model_len: 2048
+ serve_module: axolotl.scripts.vllm_serve_lora
+ enforce_eager: true
+
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_opencode.transform
+    split: train[:500]
+
+sequence_len: 1024
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 10
+
+learning_rate: 5.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 3
+weight_decay: 0.01
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.0
+# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
+# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
+lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
+
+bf16: auto
+flash_attention: true
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|endoftext|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-qwen35-4b-structured-async
+
+wandb_project: ebft
+logging_steps: 1
+save_steps: 50
diff --git a/examples/ebft/qwen35-4b-ebft-structured.yaml b/examples/ebft/qwen35-4b-ebft-structured.yaml
new file mode 100644
index 000000000..9108e87e9
--- /dev/null
+++ b/examples/ebft/qwen35-4b-ebft-structured.yaml
@@ -0,0 +1,77 @@
+# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
+#
+# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
+# full attention every 4th layer. This tests EBFT compatibility.
+#
+# Prerequisites:
+# 1. Start vLLM on GPU 0:
+# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
+# --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
+#
+# 2. Run training on GPU 1:
+# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
+# axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml
+
+base_model: Qwen/Qwen3.5-4B
+
+rl: ebft
+
+ebft:
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: false
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ ce_coef: 0.0
+
+trl:
+  num_generations: 4
+  max_completion_length: 256
+  temperature: 0.7
+  use_vllm: true
+  vllm_server_host: 0.0.0.0
+  vllm_server_port: 8000
+  scale_rewards: true
+  loss_type: grpo
+  epsilon: 0.2
+  generation_kwargs:
+    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
+  chat_template_kwargs:
+    enable_thinking: false  # disable Qwen3.5 thinking mode for shorter completions
+
+datasets:
+  - path: nvidia/OpenCodeInstruct
+    type: ebft_opencode.transform
+    split: train[:500]
+
+sequence_len: 1024
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 10
+
+learning_rate: 5.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 3
+weight_decay: 0.01
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.0
+lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
+
+bf16: auto
+flash_attention: true
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|endoftext|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-qwen35-4b-structured
+
+wandb_project: ebft
+logging_steps: 1
+save_steps: 50
diff --git a/examples/ebft/qwen35-9b-ebft-structured.yaml b/examples/ebft/qwen35-9b-ebft-structured.yaml
new file mode 100644
index 000000000..e79fb5fbf
--- /dev/null
+++ b/examples/ebft/qwen35-9b-ebft-structured.yaml
@@ -0,0 +1,82 @@
+# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
+#
+# Prerequisites:
+# 1. Start vLLM on GPU 0:
+# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
+#
+# 2. Run training on GPU 1:
+# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
+# axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml
+
+base_model: Qwen/Qwen3.5-9B
+
+rl: ebft
+
+ebft:
+ feature_layers: [0.25, 0.5, 0.75]
+ embed_method: last_token
+ use_whitening: false
+ alignment_coef: 1.0
+ diversity_coef: 1.0
+ ce_coef: 0.0
+
+trl:
+ num_generations: 4
+ max_completion_length: 256
+ temperature: 0.7
+ use_vllm: true
+ vllm_server_host: 0.0.0.0
+ vllm_server_port: 8000
+ scale_rewards: true
+ loss_type: grpo
+ epsilon: 0.2
+ generation_kwargs:
+ stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
+ chat_template_kwargs:
+ enable_thinking: false
+ vllm_server_timeout: 300
+
+vllm:
+ gpu_memory_utilization: 0.7
+ max_model_len: 2048
+ serve_module: axolotl.scripts.vllm_serve_lora
+ enforce_eager: true
+
+datasets:
+ - path: nvidia/OpenCodeInstruct
+ type: ebft_opencode.transform
+ split: train[:500]
+
+sequence_len: 1024
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+num_epochs: 1
+max_steps: 10
+
+learning_rate: 3.0e-6
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+warmup_steps: 3
+weight_decay: 0.01
+
+adapter: lora
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.0
+# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
+# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
+lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
+
+bf16: auto
+flash_attention: true
+gradient_checkpointing: true
+
+special_tokens:
+ pad_token: "<|endoftext|>"
+
+val_set_size: 0.0
+output_dir: ./outputs/ebft-qwen35-9b-structured
+
+wandb_project: ebft
+logging_steps: 1
+save_steps: 50
diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py
index 10db23878..2180a9e7f 100644
--- a/src/axolotl/cli/vllm_serve.py
+++ b/src/axolotl/cli/vllm_serve.py
@@ -38,18 +38,14 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
- # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
+ # Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
+ # We default to axolotl's serve module instead of TRL's because TRL's sends
+ # truncate_prompt_tokens which is unsupported in vLLM 0.17+.
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
- if (
- serve_module is None
- and getattr(cfg, "trl", None)
- and getattr(cfg.trl, "vllm_lora_sync", False)
- ):
- serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
- serve_module = "trl.scripts.vllm_serve"
+ serve_module = "axolotl.scripts.vllm_serve_lora"
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -79,6 +75,12 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
+ cli_enforce_eager = cli_args.get("enforce_eager")
+ cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
+ raw_enforce_eager = (
+ cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
+ )
+ enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
base_kwargs = dict(
model=model,
tensor_parallel_size=tensor_parallel_size,
@@ -89,6 +91,7 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
+ enforce_eager=enforce_eager,
)
# Use LoRAScriptArguments when serving with native LoRA support
@@ -98,6 +101,10 @@ def do_vllm_serve(
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
+ # Disable native LoRA in vLLM if not using vllm_lora_sync
+ # (merged weight sync via batch_update doesn't need vLLM LoRA mode)
+ if not getattr(cfg.trl, "vllm_lora_sync", False):
+ lora_kwargs["enable_lora"] = False
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(
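Reviewer note: the `enforce_eager` resolution added in this hunk follows a "CLI value, then config value, then `False`" precedence. A minimal sketch of that rule is shown here, assuming a hypothetical helper named `resolve_bool_flag` that is not part of axolotl; it only illustrates the ordering, not the actual implementation.

```python
from typing import Optional


def resolve_bool_flag(
    cli_value: Optional[bool],
    cfg_value: Optional[bool],
    default: bool = False,
) -> bool:
    """Hypothetical helper: prefer the CLI value, then the config value, then the default."""
    # A CLI value of None means "not provided", so fall back to the config value.
    raw = cfg_value if cli_value is None else cli_value
    # If neither source provided a value, use the default.
    return default if raw is None else bool(raw)


if __name__ == "__main__":
    print(resolve_bool_flag(None, True))    # config value is used
    print(resolve_bool_flag(False, True))   # explicit CLI value overrides the config
    print(resolve_bool_flag(None, None))    # neither set, so the default applies
```

Note that an explicit `False` on the CLI overrides a `True` in the config, which a plain `cli or cfg` expression would not do.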
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index c95ddb80e..b661e74c9 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -118,7 +118,7 @@ def load_preference_datasets(
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
total_num_steps: int | None = None
- if cfg.rl is not RLType.GRPO:
+ if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index f7bf110cc..89d4c9ff7 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = AxolotlKTOTrainer
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
+ elif self.cfg.rl is RLType.EBFT:
+ from axolotl.core.trainers.ebft import EBFTStrategy
+
+ trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
+ trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
+
+ elif self.cfg.rl is RLType.EBFT:
+ from axolotl.core.trainers.ebft import EBFTStrategy
+
+ training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
+ training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
+ blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if (
self.cfg.adapter
and self.peft_config
- and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
+ and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py
index 22d8b64f6..cdc09ae0a 100644
--- a/src/axolotl/core/trainers/__init__.py
+++ b/src/axolotl/core/trainers/__init__.py
@@ -4,6 +4,8 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
+from .ebft.strided import AxolotlStridedEBFTTrainer
+from .ebft.trainer import AxolotlEBFTTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,
diff --git a/src/axolotl/core/trainers/ebft/__init__.py b/src/axolotl/core/trainers/ebft/__init__.py
new file mode 100644
index 000000000..23b61fbe6
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/__init__.py
@@ -0,0 +1,213 @@
+"""EBFT (Energy-Based Fine-Tuning) Strategy for training
+
+Two modes:
+- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
+- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
+"""
+
+from typing import Any
+
+from axolotl.core.trainers.ebft.args import (
+ AxolotlAsyncEBFTConfig,
+ AxolotlEBFTConfig,
+ AxolotlStridedEBFTConfig,
+)
+from axolotl.utils.dict import DictDefault
+
+
+def _get_ebft_mode(cfg: DictDefault) -> str:
+ """Determine EBFT mode from config."""
+ if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
+ return cfg.ebft.mode
+ return "structured"
+
+
+class EBFTStrategy:
+ """Strategy for EBFT training — dispatches between structured and strided modes."""
+
+ @classmethod
+ def get_trainer_class(cls, cfg: DictDefault | None = None):
+ mode = _get_ebft_mode(cfg) if cfg else "structured"
+ if mode == "strided":
+ from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer
+
+ return AxolotlStridedEBFTTrainer
+
+ # Structured mode: async or sync
+ use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
+ if use_async:
+ from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
+
+ return AxolotlAsyncEBFTTrainer
+ from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer
+
+ return AxolotlEBFTTrainer
+
+ @classmethod
+ def get_training_args_class(cls, cfg: DictDefault | None = None):
+ mode = _get_ebft_mode(cfg) if cfg else "structured"
+ if mode == "strided":
+ return AxolotlStridedEBFTConfig
+
+ # Structured mode: async or sync config
+ use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
+ if use_async:
+ return AxolotlAsyncEBFTConfig
+ return AxolotlEBFTConfig
+
+ @classmethod
+ def is_strided(cls, cfg: DictDefault) -> bool:
+ return _get_ebft_mode(cfg) == "strided"
+
+ @classmethod
+ def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
+ """Map axolotl YAML config fields to training args kwargs."""
+ kwargs: dict[str, Any] = {}
+ mode = _get_ebft_mode(cfg)
+
+ # Common EBFT fields
+ ebft = cfg.ebft
+ if ebft:
+ if ebft.feature_layers is not None:
+ kwargs["ebft_feature_layers"] = ebft.feature_layers
+ if ebft.embed_method is not None:
+ kwargs["ebft_embed_method"] = ebft.embed_method
+ if ebft.use_whitening is not None:
+ kwargs["ebft_use_whitening"] = ebft.use_whitening
+ if ebft.alignment_coef is not None:
+ kwargs["ebft_alignment_coef"] = ebft.alignment_coef
+ if ebft.diversity_coef is not None:
+ kwargs["ebft_diversity_coef"] = ebft.diversity_coef
+ if ebft.ce_coef is not None:
+ kwargs["ebft_ce_coef"] = ebft.ce_coef
+ if getattr(ebft, "adaptive_max_tokens", None) is not None:
+ kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
+ if getattr(ebft, "gt_length_multiplier", None) is not None:
+ kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier
+
+ if mode == "strided":
+ # Strided-specific fields
+ if ebft:
+ if ebft.stride is not None:
+ kwargs["ebft_stride"] = ebft.stride
+ if ebft.context_length is not None:
+ kwargs["ebft_context_length"] = ebft.context_length
+ if ebft.generate_max_len is not None:
+ kwargs["ebft_generate_max_len"] = ebft.generate_max_len
+ if ebft.n_samples_per_prompt is not None:
+ kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
+ if ebft.temperature is not None:
+ kwargs["ebft_temperature"] = ebft.temperature
+ if ebft.top_p is not None:
+ kwargs["ebft_top_p"] = ebft.top_p
+ if ebft.rl_coef is not None:
+ kwargs["ebft_rl_coef"] = ebft.rl_coef
+ if ebft.advantage_estimator is not None:
+ kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
+ if ebft.min_completion_prefix is not None:
+ kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
+ else:
+ # Structured mode: map TRL config fields
+ trl = cfg.trl
+ if trl:
+ if trl.use_vllm:
+ kwargs["use_vllm"] = trl.use_vllm
+ if trl.vllm_mode:
+ kwargs["vllm_mode"] = trl.vllm_mode
+ if trl.vllm_mode == "colocate":
+ kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
+ vllm_cfg = cfg.vllm
+ if vllm_cfg:
+ kwargs["vllm_gpu_memory_utilization"] = (
+ vllm_cfg.gpu_memory_utilization
+ )
+ kwargs["vllm_tensor_parallel_size"] = (
+ vllm_cfg.tensor_parallel_size
+ )
+ kwargs["vllm_server_host"] = trl.vllm_server_host or (
+ trl.vllm.host if trl.vllm else None
+ )
+ kwargs["vllm_server_port"] = trl.vllm_server_port or (
+ trl.vllm.port if trl.vllm else None
+ )
+ if trl.vllm_server_timeout:
+ kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
+
+ if trl.num_generations:
+ kwargs["num_generations"] = trl.num_generations
+ if trl.max_completion_length is not None:
+ kwargs["max_completion_length"] = trl.max_completion_length
+ if trl.temperature is not None:
+ kwargs["temperature"] = trl.temperature
+ if trl.top_p is not None:
+ kwargs["top_p"] = trl.top_p
+ if trl.top_k is not None:
+ kwargs["top_k"] = trl.top_k
+ if trl.min_p is not None:
+ kwargs["min_p"] = trl.min_p
+ if trl.num_iterations is not None:
+ kwargs["num_iterations"] = trl.num_iterations
+ if trl.epsilon is not None:
+ kwargs["epsilon"] = trl.epsilon
+ if trl.epsilon_high is not None:
+ kwargs["epsilon_high"] = trl.epsilon_high
+ if trl.scale_rewards is not None:
+ kwargs["scale_rewards"] = trl.scale_rewards
+ if trl.loss_type is not None:
+ kwargs["loss_type"] = trl.loss_type
+ if trl.mask_truncated_completions is not None:
+ kwargs["mask_truncated_completions"] = (
+ trl.mask_truncated_completions
+ )
+ if trl.log_completions is not None:
+ kwargs["log_completions"] = trl.log_completions
+ if trl.num_completions_to_print is not None:
+ kwargs["num_completions_to_print"] = trl.num_completions_to_print
+ if trl.sync_ref_model:
+ kwargs["sync_ref_model"] = trl.sync_ref_model
+ if trl.repetition_penalty is not None:
+ kwargs["repetition_penalty"] = trl.repetition_penalty
+ if trl.generation_kwargs is not None:
+ kwargs["generation_kwargs"] = trl.generation_kwargs
+ if trl.chat_template_kwargs is not None:
+ kwargs["chat_template_kwargs"] = trl.chat_template_kwargs
+
+ # Async prefetch fields (only pass when enabled — sync config doesn't have these)
+ if getattr(trl, "async_prefetch", False):
+ kwargs["async_prefetch"] = trl.async_prefetch
+ if getattr(trl, "vllm_sync_interval", None) is not None:
+ kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
+ if getattr(trl, "vllm_lora_sync", False):
+ kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
+
+ return kwargs
+
+ @classmethod
+ def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
+ return []
+
+ @classmethod
+ def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
+ return {}
+
+ @classmethod
+ def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
+ mode = _get_ebft_mode(cfg) if cfg else "structured"
+ if mode == "strided":
+ return [
+ "dataset_num_proc",
+ "max_length",
+ "max_prompt_length",
+ "include_tokens_per_second",
+ "beta",
+ ]
+ return [
+ "dataset_num_proc",
+ "max_length",
+ "include_tokens_per_second",
+ "max_prompt_length",
+ ]
+
+ @classmethod
+ def get_collator(cls, *args, **kwargs):
+ return None
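Reviewer note: the dispatch in `EBFTStrategy.get_trainer_class` can be summarized as "strided mode wins, otherwise `trl.async_prefetch` selects the async trainer, otherwise the sync trainer". The sketch below mirrors that decision with plain namespaces instead of `DictDefault`; the function `pick_trainer_name` is hypothetical and returns class names as strings so that it runs without axolotl installed.

```python
from types import SimpleNamespace


def pick_trainer_name(cfg) -> str:
    """Hypothetical stand-in for the trainer dispatch; returns a class name."""
    ebft = getattr(cfg, "ebft", None)
    # A missing or empty mode falls back to "structured", as in _get_ebft_mode.
    mode = getattr(ebft, "mode", None) or "structured"
    if mode == "strided":
        return "AxolotlStridedEBFTTrainer"

    trl = getattr(cfg, "trl", None)
    if trl is not None and getattr(trl, "async_prefetch", False):
        return "AxolotlAsyncEBFTTrainer"
    return "AxolotlEBFTTrainer"


if __name__ == "__main__":
    strided = SimpleNamespace(ebft=SimpleNamespace(mode="strided"), trl=None)
    async_cfg = SimpleNamespace(ebft=None, trl=SimpleNamespace(async_prefetch=True))
    print(pick_trainer_name(strided))
    print(pick_trainer_name(async_cfg))
```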
diff --git a/src/axolotl/core/trainers/ebft/args.py b/src/axolotl/core/trainers/ebft/args.py
new file mode 100644
index 000000000..4a31b6b6b
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/args.py
@@ -0,0 +1,134 @@
+"""
+EBFT-specific training arguments.
+
+Three config classes:
+- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
+- AxolotlAsyncEBFTConfig: extends FastAsyncGRPOConfig for async structured QA data
+- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
+"""
+
+from dataclasses import dataclass, field
+
+from transformers import TrainingArguments
+from trl import GRPOConfig
+
+from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
+from axolotl.core.training_args import AxolotlTrainingMixins
+
+
+# -- Shared EBFT fields as a mixin --
+@dataclass
+class EBFTFieldsMixin:
+ """Common fields shared between structured and strided EBFT configs."""
+
+ ebft_feature_layers: list[float] = field(
+ default_factory=lambda: [0.25, 0.5, 0.75],
+ metadata={"help": "Fractional layer depths for feature extraction"},
+ )
+ ebft_embed_method: str = field(
+ default="last_token",
+ metadata={"help": "Pooling method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"},
+ )
+ ebft_use_whitening: bool = field(
+ default=False,
+ metadata={"help": "Apply SVD whitening to feature embeddings"},
+ )
+ ebft_alignment_coef: float = field(
+ default=1.0,
+ metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
+ )
+ ebft_diversity_coef: float = field(
+ default=1.0,
+ metadata={"help": "Coefficient for diversity penalty"},
+ )
+ ebft_ce_coef: float = field(
+ default=0.0,
+ metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
+ )
+ ebft_adaptive_max_tokens: bool = field(
+ default=True,
+ metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
+ )
+ ebft_gt_length_multiplier: float = field(
+ default=1.5,
+ metadata={
+ "help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
+ },
+ )
+
+
+# -- Structured mode: extends GRPOConfig for QA data with vLLM --
+@dataclass
+class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
+ """EBFT config for structured QA data — extends GRPOConfig."""
+
+ vllm_lora_sync: bool = field(
+ default=False,
+ metadata={
+ "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
+ },
+ )
+
+
+# -- Async structured mode: extends FastAsyncGRPOConfig --
+@dataclass
+class AxolotlAsyncEBFTConfig(
+ EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
+):
+ """EBFT config for async structured QA data — extends FastAsyncGRPOConfig.
+
+ Includes all async fields: async_prefetch, vllm_lora_sync,
+ skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
+ """
+
+ vllm_lora_sync: bool = field(
+ default=False,
+ metadata={
+ "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
+ },
+ )
+
+
+# -- Strided mode: extends TrainingArguments for unstructured text --
+@dataclass
+class AxolotlStridedEBFTConfig(
+ EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
+):
+ """EBFT config for unstructured text with strided block-parallel generation."""
+
+ ebft_stride: int = field(
+ default=8,
+ metadata={"help": "Stride between anchor points (in tokens)"},
+ )
+ ebft_context_length: int = field(
+ default=8,
+ metadata={"help": "Context window size for each block"},
+ )
+ ebft_generate_max_len: int = field(
+ default=8,
+ metadata={"help": "Number of tokens to generate per block"},
+ )
+ ebft_n_samples_per_prompt: int = field(
+ default=4,
+ metadata={"help": "Number of independent rollouts per document"},
+ )
+ ebft_temperature: float = field(
+ default=0.6,
+ metadata={"help": "Sampling temperature for strided generation"},
+ )
+ ebft_top_p: float = field(
+ default=1.0,
+ metadata={"help": "Top-p nucleus sampling threshold"},
+ )
+ ebft_rl_coef: float = field(
+ default=1.0,
+ metadata={"help": "RL policy gradient loss coefficient"},
+ )
+ ebft_advantage_estimator: str = field(
+ default="rloo",
+ metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
+ )
+ ebft_min_completion_prefix: int = field(
+ default=0,
+ metadata={"help": "Minimum tokens into completion before placing anchors"},
+ )
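Reviewer note: `ebft_feature_layers` holds fractional depths, while `extract_hidden_states` in `rewards.py` takes integer layer indices (its docstring gives `[8, 16, 24]` for a 32-layer model). The conversion itself is not part of this hunk, so the sketch below is an assumption: it rounds `fraction * num_layers` and clamps to the valid range of the `hidden_states` tuple, where index 0 is the embedding output. The name `fractions_to_layer_indices` is hypothetical.

```python
def fractions_to_layer_indices(fractions: list[float], num_layers: int) -> list[int]:
    """Hypothetical mapping from fractional depths to hidden_states indices.

    Assumes hidden_states has num_layers + 1 entries (index 0 is the
    embedding output), so valid indices are 0..num_layers inclusive.
    """
    indices = []
    for frac in fractions:
        idx = int(round(frac * num_layers))
        # Clamp so that out-of-range fractions still give a valid index.
        indices.append(max(0, min(num_layers, idx)))
    return indices


if __name__ == "__main__":
    print(fractions_to_layer_indices([0.25, 0.5, 0.75], num_layers=32))
```

With the default fractions and a 32-layer model this reproduces the `[8, 16, 24]` example from the `extract_hidden_states` docstring.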
diff --git a/src/axolotl/core/trainers/ebft/kernels.py b/src/axolotl/core/trainers/ebft/kernels.py
new file mode 100644
index 000000000..d1d35feb4
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/kernels.py
@@ -0,0 +1,309 @@
+"""
+Fused Triton kernels for strided EBFT.
+
+These kernels eliminate intermediate tensor materializations that dominate
+the elementwise/fill category (~40% of CUDA time in profiling).
+
+Kernels:
+ 1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
+ 2. fused_reinforce_loss: -logp * advantage * mask, reduced to a scalar mean over masked tokens
+ 3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
+ 4. fused_diversity_penalty: mean pairwise dot product within each sample group, excluding self
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+# ---------------------------------------------------------------------------
+# 1. Fused log_softmax + gather (selective log softmax)
+# ---------------------------------------------------------------------------
+# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
+# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
+# This avoids materializing the full (B, S, V) log_softmax output.
+
+
+@triton.jit
+def _fused_log_softmax_gather_kernel(
+ logits_ptr, # (B*S, V) row-major
+ labels_ptr, # (B*S,) int64
+ output_ptr, # (B*S,) float32
+ V: tl.constexpr, # vocab size
+ BLOCK_V: tl.constexpr, # tile width over vocab
+):
+ """Compute log_softmax(logits)[label] for each row without materializing full output."""
+ row = tl.program_id(0)
+
+ logits_row_ptr = logits_ptr + row * V
+ label = tl.load(labels_ptr + row)
+
+ # Pass 1: find max for numerical stability
+ max_val = -float("inf")
+ for v_start in range(0, V, BLOCK_V):
+ v_offsets = v_start + tl.arange(0, BLOCK_V)
+ mask = v_offsets < V
+ vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
+ max_val = tl.maximum(max_val, tl.max(vals, axis=0))
+
+ # Pass 2: compute sum(exp(x - max))
+ sum_exp = 0.0
+ for v_start in range(0, V, BLOCK_V):
+ v_offsets = v_start + tl.arange(0, BLOCK_V)
+ mask = v_offsets < V
+ vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
+ sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)
+
+ log_sum_exp = tl.log(sum_exp)
+
+ # Gather: log_softmax[label] = logits[label] - max - log_sum_exp
+ target_logit = tl.load(logits_row_ptr + label)
+ result = target_logit - max_val - log_sum_exp
+
+ tl.store(output_ptr + row, result)
+
+
+def fused_log_softmax_gather(
+ logits: torch.Tensor, labels: torch.Tensor
+) -> torch.Tensor:
+ """Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
+
+ Args:
+ logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
+ labels: (B, S) or (B*S,) int64 tensor of target indices
+
+ Returns:
+ (B, S) or (B*S,) float32 tensor of selected log probabilities
+ """
+ orig_shape = logits.shape[:-1]
+ V = logits.shape[-1]
+ logits_2d = logits.reshape(-1, V).contiguous()
+ labels_1d = labels.reshape(-1).contiguous()
+ n_rows = logits_2d.shape[0]
+
+ output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
+
+ # Choose BLOCK_V: must be power of 2, large enough for good occupancy
+ BLOCK_V = min(triton.next_power_of_2(V), 65536)
+
+ _fused_log_softmax_gather_kernel[(n_rows,)](
+ logits_2d,
+ labels_1d,
+ output,
+ V=V,
+ BLOCK_V=BLOCK_V,
+ )
+
+ return output.view(orig_shape)
+
+
+# ---------------------------------------------------------------------------
+# 2. Fused masked REINFORCE loss reduction
+# ---------------------------------------------------------------------------
+# Instead of: (-logp * adv * mask).sum() / mask.sum()
+# We do the masked multiply-accumulate in one kernel, writing per-block partial (sum, count) pairs that the wrapper reduces.
+
+
+@triton.jit
+def _fused_reinforce_loss_kernel(
+ logps_ptr, # (N,) float32 per-token log probs
+ advs_ptr, # (N,) float32 advantages
+ mask_ptr, # (N,) bool action mask
+ partial_sum_ptr, # (n_blocks,) partial sums
+ partial_cnt_ptr, # (n_blocks,) partial counts
+ N: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ block_id = tl.program_id(0)
+ offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+ valid = offsets < N
+
+ logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
+ advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
+ m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)
+
+ # -logp * advantage * mask
+ loss = -logps * advs * m
+ block_sum = tl.sum(loss, axis=0)
+ block_cnt = tl.sum(m, axis=0)
+
+ tl.store(partial_sum_ptr + block_id, block_sum)
+ tl.store(partial_cnt_ptr + block_id, block_cnt)
+
+
+def fused_reinforce_loss(
+ per_token_logps: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: torch.Tensor,
+) -> torch.Tensor:
+ """Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
+
+ All inputs are flattened before the kernel launch. Returns a scalar loss tensor.
+ """
+ logps_flat = per_token_logps.reshape(-1).contiguous()
+ advs_flat = advantages.reshape(-1).contiguous()
+ mask_flat = action_mask.reshape(-1).contiguous()
+ N = logps_flat.shape[0]
+
+ BLOCK_N = 1024
+ n_blocks = triton.cdiv(N, BLOCK_N)
+
+ partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
+ partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
+
+ _fused_reinforce_loss_kernel[(n_blocks,)](
+ logps_flat,
+ advs_flat,
+ mask_flat,
+ partial_sum,
+ partial_cnt,
+ N=N,
+ BLOCK_N=BLOCK_N,
+ )
+
+ total_sum = partial_sum.sum()
+ total_cnt = partial_cnt.sum().clamp(min=1)
+ return total_sum / total_cnt
+
+
+# ---------------------------------------------------------------------------
+# 3. Fused cosine similarity (batched, for EBFT rewards)
+# ---------------------------------------------------------------------------
+# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
+# we fuse the dot product, norm computation, and division into one kernel.
+
+
+@triton.jit
+def _fused_cosine_sim_kernel(
+ a_ptr, # (N, D) first set of vectors
+ b_ptr, # (N, D) second set of vectors
+ out_ptr, # (N,) cosine similarities
+ D: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ row = tl.program_id(0)
+ a_row_ptr = a_ptr + row * D
+ b_row_ptr = b_ptr + row * D
+
+ dot = 0.0
+ norm_a = 0.0
+ norm_b = 0.0
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offsets = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offsets < D
+ a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
+ b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
+
+ dot += tl.sum(a_vals * b_vals, axis=0)
+ norm_a += tl.sum(a_vals * a_vals, axis=0)
+ norm_b += tl.sum(b_vals * b_vals, axis=0)
+
+ denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
+ denom = tl.where(denom > 1e-8, denom, 1e-8)
+ result = dot / denom
+
+ tl.store(out_ptr + row, result)
+
+
+def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """Compute cosine similarity along the last dimension.
+
+ Args:
+ a, b: (..., D) tensors of the same shape
+
+ Returns:
+ (...,) tensor of cosine similarities
+ """
+ orig_shape = a.shape[:-1]
+ D = a.shape[-1]
+ a_2d = a.reshape(-1, D).contiguous()
+ b_2d = b.reshape(-1, D).contiguous()
+ N = a_2d.shape[0]
+
+ output = torch.empty(N, device=a.device, dtype=torch.float32)
+
+ BLOCK_D = min(triton.next_power_of_2(D), 4096)
+
+ _fused_cosine_sim_kernel[(N,)](
+ a_2d,
+ b_2d,
+ output,
+ D=D,
+ BLOCK_D=BLOCK_D,
+ )
+
+ return output.view(orig_shape)
+
+
+# ---------------------------------------------------------------------------
+# 4. Fused pairwise diversity penalty
+# ---------------------------------------------------------------------------
+# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1)
+# We compute the pairwise dot products and exclusion in one kernel.
+
+
+@triton.jit
+def _fused_diversity_kernel(
+ emb_ptr, # (B, N, D) embeddings, row-major
+ out_ptr, # (B, N) diversity penalties
+ N: tl.constexpr, # n_samples
+ D: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """For each (b, i), compute mean dot product to all j != i."""
+ b = tl.program_id(0)
+ i = tl.program_id(1)
+
+ # Pointer to emb[b, i, :]
+ emb_bi_ptr = emb_ptr + (b * N + i) * D
+
+ total_sim = 0.0
+ for j in range(N):
+ emb_bj_ptr = emb_ptr + (b * N + j) * D
+
+ dot = 0.0
+ for d_start in range(0, D, BLOCK_D):
+ d_offsets = d_start + tl.arange(0, BLOCK_D)
+ d_mask = d_offsets < D
+ a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
+ tl.float32
+ )
+ b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
+ tl.float32
+ )
+ dot += tl.sum(a_vals * b_vals, axis=0)
+
+ # Exclude self-similarity (j == i)
+ is_other = j != i
+ total_sim += dot * is_other
+
+ result = total_sim / (N - 1)
+ tl.store(out_ptr + b * N + i, result)
+
+
+def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
+ """Compute mean pairwise dot product (excluding self) per sample.
+
+ Args:
+ embeddings: (B, N, D) tensor where N is n_samples
+
+ Returns:
+ (B, N) tensor of diversity penalties
+ """
+ B, N, D = embeddings.shape
+ embeddings = embeddings.contiguous()
+ output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
+ if N <= 1:
+ return output # diversity is 0 when there's only one sample
+
+ BLOCK_D = min(triton.next_power_of_2(D), 4096)
+
+ _fused_diversity_kernel[(B, N)](
+ embeddings,
+ output,
+ N=N,
+ D=D,
+ BLOCK_D=BLOCK_D,
+ )
+
+ return output
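Reviewer note: a slow reference implementation is useful for checking the fused kernels numerically. The sketch below restates three of them (selective log-softmax, masked REINFORCE loss, and the pairwise diversity penalty) in pure Python on lists, following the formulas in the kernel comments. The function names are hypothetical and it is not part of the PR; it is meant only as a readable specification of what each kernel computes.

```python
import math


def log_softmax_gather(logits_row: list[float], label: int) -> float:
    """log_softmax(logits_row)[label] via the max / sum-exp / gather passes."""
    max_val = max(logits_row)                                  # pass 1: row max
    sum_exp = sum(math.exp(x - max_val) for x in logits_row)   # pass 2: sum of exp
    return logits_row[label] - max_val - math.log(sum_exp)     # gather


def masked_reinforce_loss(
    logps: list[float], advantages: list[float], mask: list[bool]
) -> float:
    """(-logp * adv * mask).sum() / max(mask.sum(), 1)."""
    total = sum(-lp * adv for lp, adv, m in zip(logps, advantages, mask) if m)
    count = sum(1 for m in mask if m)
    return total / max(count, 1)


def diversity_penalty(group: list[list[float]]) -> list[float]:
    """Mean dot product of each embedding with the others in its group."""
    n = len(group)
    if n <= 1:
        return [0.0] * n
    penalties = []
    for i, a in enumerate(group):
        total = sum(
            sum(x * y for x, y in zip(a, b))
            for j, b in enumerate(group)
            if j != i  # exclude self-similarity
        )
        penalties.append(total / (n - 1))
    return penalties


if __name__ == "__main__":
    print(log_softmax_gather([0.0, 0.0], 0))
    print(masked_reinforce_loss([-1.0, -2.0], [1.0, 1.0], [True, False]))
    print(diversity_penalty([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
```

Comparing these values against the Triton outputs on small random inputs, within a floating-point tolerance, is one way to catch indexing or masking mistakes.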
diff --git a/src/axolotl/core/trainers/ebft/rewards.py b/src/axolotl/core/trainers/ebft/rewards.py
new file mode 100644
index 000000000..993650bbe
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/rewards.py
@@ -0,0 +1,264 @@
+"""
+Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).
+
+Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
+Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
+ (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+@torch.no_grad()
+def extract_hidden_states(
+ model: nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_indices: list[int],
+ batch_size: int | None = None,
+) -> torch.Tensor:
+ """
+ Forward pass through model, extracting and concatenating hidden states
+ at specified layer indices.
+
+ Args:
+ model: The frozen feature network
+ input_ids: (B, S) token ids
+ attention_mask: (B, S) attention mask
+ layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model)
+ batch_size: If set, process in chunks to reduce peak memory
+
+ Returns:
+ Concatenated hidden states: (B, S, num_layers * H)
+ """
+ if batch_size is None:
+ batch_size = input_ids.shape[0]
+
+ # Use the inner transformer body (skips lm_head) when available.
+ # This avoids the expensive hidden_dim × vocab_size matmul whose
+ # output (logits) is never used — only hidden_states are needed.
+ body = getattr(model, "model", None)
+ if body is not None and hasattr(body, "forward"):
+ forward_model = body
+ else:
+ forward_model = model
+
+ all_features = []
+ for i in range(0, input_ids.shape[0], batch_size):
+ chunk_ids = input_ids[i : i + batch_size]
+ chunk_mask = attention_mask[i : i + batch_size]
+
+ outputs = forward_model(
+ chunk_ids,
+ attention_mask=chunk_mask,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ # hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H)
+ # index 0 is the embedding layer output
+ hidden_states = outputs.hidden_states
+ chunk_features = []
+ for idx in layer_indices:
+ chunk_features.append(hidden_states[idx])
+
+ # Concatenate across feature dimension: (B, S, num_layers * H)
+ all_features.append(torch.cat(chunk_features, dim=-1))
+
+ return torch.cat(all_features, dim=0)
+
+
+def apply_embed_method(
+ hidden_states: torch.Tensor,
+ method: str,
+ attention_mask: torch.Tensor | None = None,
+ prompt_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+ """
+ Pool per-token hidden states into per-sequence embeddings.
+
+ Args:
+ hidden_states: (B, S, D) concatenated hidden states
+ method: One of "last_token", "mean_pooling", "completion_mean", "concat"
+ attention_mask: (B, S) mask for mean pooling
+ prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)
+
+ Returns:
+ Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
+ (B, 3*D) for concat
+ """
+ if method == "last_token":
+ if attention_mask is not None:
+ # Find last non-padding position per sample (assumes right-padding)
+ last_idx = attention_mask.sum(dim=1).long() - 1 # (B,)
+ return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
+ return hidden_states[:, -1, :]
+
+ if method == "mean_pooling":
+ if attention_mask is not None:
+ mask = attention_mask.unsqueeze(-1).float() # (B, S, 1)
+ return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+ return hidden_states.mean(dim=1)
+
+ if method == "completion_mean":
+ # Mean pool over completion tokens only (exclude prompt)
+ if prompt_lengths is None:
+ raise ValueError("completion_mean requires prompt_lengths")
+ B, S, _ = hidden_states.shape
+ positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S)
+ comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S)
+ if attention_mask is not None:
+ comp_mask = comp_mask & attention_mask.bool()
+ mask = comp_mask.unsqueeze(-1).float() # (B, S, 1)
+ return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+
+ if method == "concat":
+ B, S, D = hidden_states.shape
+ if attention_mask is not None:
+ valid_lens = attention_mask.sum(dim=1).long() # (B,)
+ else:
+ valid_lens = torch.full(
+ (B,), S, device=hidden_states.device, dtype=torch.long
+ )
+ # Compute quartile positions relative to the valid length of each sample.
+ # Assumes right-padding, so valid tokens occupy positions [0, valid_len).
+ q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
+ q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
+ q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
+ batch_idx = torch.arange(B, device=hidden_states.device)
+ return torch.cat(
+ [
+ hidden_states[batch_idx, q1],
+ hidden_states[batch_idx, q2],
+ hidden_states[batch_idx, q3],
+ ],
+ dim=-1,
+ )
+
+ raise ValueError(f"Unknown embed_method: {method}")
+
+
+@torch.no_grad()
+def get_alignment_rewards(
+ gen_embedding: torch.Tensor,
+ gt_embedding: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Compute alignment reward as cosine similarity between generated
+ and ground-truth feature embeddings.
+
+ Args:
+ gen_embedding: (B, D) generated sequence embeddings
+ gt_embedding: (B, D) ground-truth sequence embeddings
+ If num_generations > 1, gt_embedding should be repeated
+ to match gen_embedding's batch dim.
+
+ Returns:
+ Alignment rewards: (B,) cosine similarities in [-1, 1]
+ """
+ return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)
+
+
+@torch.no_grad()
+def get_diversity_rewards(
+ gen_embedding: torch.Tensor,
+ num_generations: int,
+) -> torch.Tensor:
+ """
+ Compute diversity penalty as mean pairwise dot-product similarity
+ between samples from the same prompt (excluding self-similarity).
+
+ Args:
+ gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
+ num_generations: Number of generations per prompt
+
+ Returns:
+ Diversity penalties: (B,) mean similarity to other samples from same prompt
+ """
+ if num_generations <= 1:
+ return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)
+
+ num_prompts = gen_embedding.shape[0] // num_generations
+
+ # Reshape to (num_prompts, num_generations, D)
+ reshaped = gen_embedding.view(num_prompts, num_generations, -1)
+
+ # Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
+ sims = torch.bmm(reshaped, reshaped.transpose(1, 2))
+
+ # Zero out self-similarity (diagonal)
+ eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
+ sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
+
+ # Mean similarity to other samples: (num_prompts, num_generations)
+ diversity = sims.sum(dim=-1) / (num_generations - 1)
+
+ # Flatten back to (B,)
+ return diversity.view(-1)
+
+
+def whiten_embeddings_batched(
+ phi: torch.Tensor,
+ phi_gt: torch.Tensor,
+ whiten_tol: float = 1e-5,
+ normalize: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Whiten generated embeddings using SVD, then apply same transform to ground-truth.
+
+ Whitening decorrelates feature dimensions so no single direction dominates
+ the feature-matching loss. Uses pseudo-inverse for rank-deficient cases.
+
+ Note: Singular values scale with sqrt(B), so reward magnitudes are
+ batch-size dependent. This is acceptable because B = n_samples_per_prompt
+ which is fixed during training (typically 2-4).
+
+ Args:
+ phi: (B, D) generated embeddings (used to estimate covariance)
+ phi_gt: (B, D) ground-truth embeddings
+ whiten_tol: Tolerance for singular value cutoff
+ normalize: If True, L2-normalize after whitening
+
+ Returns:
+ Whitened (phi, phi_gt) tuple, each (B, D)
+ """
+ phi_f = phi.float()
+ phi_gt_f = phi_gt.float()
+
+ # Feature-space SVD: operate on phi_f.T (D, B); with full_matrices=False, U is (D, min(D, B))
+ try:
+ U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
+ except torch.linalg.LinAlgError:
+ # Fallback: add small noise
+ noise = 1e-6 * phi_f.abs().mean()
+ try:
+ U, S, _ = torch.linalg.svd(
+ (phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
+ full_matrices=False,
+ )
+ except torch.linalg.LinAlgError:
+ if normalize:
+ return (
+ F.normalize(phi, p=2, dim=-1),
+ F.normalize(phi_gt, p=2, dim=-1),
+ )
+ return phi, phi_gt
+
+ U, S = U.squeeze(0), S.squeeze(0) # U: (D, min(D,B)), S: (min(D,B),)
+
+ # Safe inverse of singular values
+ s_max = S.max()
+ inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))
+
+ # W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
+ W = (U * inv_s.unsqueeze(0)) @ U.T # (D, D)
+ phi_w = (phi_f @ W).to(phi.dtype) # (B, D)
+ phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype) # (B, D)
+
+ if normalize:
+ phi_w = F.normalize(phi_w, p=2, dim=-1)
+ phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)
+
+ return phi_w, phi_gt_w
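
Review note on `get_diversity_rewards` above: the grouping and self-similarity masking can be checked by hand with a small torch-free sketch. This is only an illustration under stated assumptions, not the implementation in this patch: the name `diversity_penalties` and the `example` data are hypothetical, and plain Python lists stand in for the `(B, D)` tensor.

```python
def diversity_penalties(embeddings, num_generations):
    """Mean dot product between each embedding and the other samples of its prompt.

    `embeddings` is a list of B vectors laid out as consecutive groups of
    `num_generations` samples per prompt, mirroring the (B, D) layout above.
    """
    if num_generations <= 1:
        # A single sample has no peers, so the penalty is zero.
        return [0.0] * len(embeddings)

    penalties = []
    for start in range(0, len(embeddings), num_generations):
        group = embeddings[start:start + num_generations]
        for i, a in enumerate(group):
            # Sum dot products with every other sample in the group (skip self).
            total = sum(
                sum(x * y for x, y in zip(a, b))
                for j, b in enumerate(group)
                if j != i
            )
            penalties.append(total / (num_generations - 1))
    return penalties


# Two prompts with two generations each (B = 4, D = 2).
# Prompt 1 has orthogonal samples; prompt 2 has identical samples.
example = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]]
result = diversity_penalties(example, num_generations=2)
```

In this sketch the orthogonal pair receives no penalty while the identical pair receives the maximum penalty for unit vectors, which matches the intent of penalizing collapsed samples within a prompt group.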
diff --git a/src/axolotl/core/trainers/ebft/strided.py b/src/axolotl/core/trainers/ebft/strided.py
new file mode 100644
index 000000000..5cfc5b99b
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/strided.py
@@ -0,0 +1,1152 @@
+"""
+Strided block-parallel EBFT trainer for unstructured text data.
+
+This trainer implements the full EBFT algorithm from the paper, including
+strided block-parallel generation where multiple short rollouts are generated
+at different anchor points within a single document. This is essential for
+training on raw text data (code, prose, etc.) without prompt/completion splits.
+
+Uses torch flex_attention with a compiled block mask for efficient strided
+attention patterns. Falls back to eager attention with dense 4D masks when
+flex_attention is not available.
+
+Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
+ (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
+"""
+
+import contextlib
+import copy
+
+import torch
+import torch.nn.functional as F
+from transformers import Trainer
+
+from axolotl.core.trainers.ebft.kernels import ( # noqa: F401 — available for future use
+ fused_cosine_similarity,
+ fused_diversity_penalty,
+ fused_log_softmax_gather,
+ fused_reinforce_loss,
+)
+from axolotl.core.trainers.ebft.rewards import (
+ whiten_embeddings_batched,
+)
+from axolotl.core.trainers.mixins import (
+ DistributedParallelMixin,
+ RngLoaderMixin,
+ SchedulerMixin,
+)
+from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+# Check flex_attention availability
+_FLEX_ATTENTION_AVAILABLE = False
+try:
+ from torch.nn.attention.flex_attention import (
+ create_block_mask,
+ )
+
+ _FLEX_ATTENTION_AVAILABLE = True
+except ImportError:
+ pass
+
+
+def _patch_flex_attention_dtype():
+ """
+ Patch HF's flex_attention_forward to cast q/k/v to a uniform dtype.
+
+ This fixes the incompatibility between flex_attention and gradient
+ checkpointing, where recomputation can produce q/k in float32 while
+ v stays in bfloat16. flex_attention requires all three to match.
+ """
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+ original_fn = ALL_ATTENTION_FUNCTIONS.get("flex_attention")
+ if original_fn is None:
+ return
+
+ def patched_flex_attention_forward(
+ module, query, key, value, attention_mask, **kwargs
+ ):
+ # Cast q/k/v to the same dtype (use value's dtype as reference,
+ # since model weights stay in the original dtype)
+ target_dtype = value.dtype
+ if query.dtype != target_dtype:
+ query = query.to(target_dtype)
+ if key.dtype != target_dtype:
+ key = key.to(target_dtype)
+ return original_fn(module, query, key, value, attention_mask, **kwargs)
+
+ ALL_ATTENTION_FUNCTIONS["flex_attention"] = patched_flex_attention_forward
+
+
+@contextlib.contextmanager
+def override_attn_implementation(model, implementation: str):
+ """Temporarily override a model's attention implementation.
+
+ Useful for forcing eager attention during generation (where sequence
+ lengths change each step, causing dynamo recompiles) while keeping
+ flex_attention for the fixed-size training forward pass.
+
+ Usage::
+
+ with override_attn_implementation(model, "eager"):
+ output = model(input_ids, attention_mask=dense_4d_mask, ...)
+ """
+ config = getattr(model, "config", None)
+ if config is None or not hasattr(config, "_attn_implementation"):
+ yield
+ return
+
+ saved = config._attn_implementation
+ config._attn_implementation = implementation
+ try:
+ yield
+ finally:
+ config._attn_implementation = saved
+
+
+# ---------------------------------------------------------------------------
+# Strided attention mask builders
+# ---------------------------------------------------------------------------
+
+
+def _strided_mask_mod(
+ b,
+ h,
+ q_idx,
+ kv_idx,
+ prompt_length,
+ context_length,
+ max_generation_length,
+ stride,
+ num_blocks,
+):
+ """
+ Mask mod function for flex_attention's create_block_mask.
+
+ Defines the strided block-parallel attention pattern:
+ - Prompt region: standard causal
+ - Generated region: each block attends to its context window + same-block predecessors
+ """
+ # --- Prompt region: standard causal ---
+ is_prompt_q = q_idx < prompt_length
+ is_prompt_kv = kv_idx < prompt_length
+ prompt_causal = is_prompt_q & is_prompt_kv & (q_idx >= kv_idx)
+
+ # --- Generated region ---
+ is_gen_q = ~is_prompt_q
+ # Which generation step and block does this query token belong to?
+ gen_offset = q_idx - prompt_length
+ gen_step = gen_offset // num_blocks
+ block_idx = gen_offset % num_blocks
+
+ # Context window end for this block.
+ # Note: if prompt_length < max_generation_length, the clamp bound is negative, so
+ # in_context is empty for all blocks. This is safe: compute_loss returns zero loss when num_blocks <= 0.
+ context_end = torch.clamp(
+ block_idx * stride + context_length,
+ max=prompt_length - max_generation_length,
+ )
+
+ # Rule 1: Generated token attends to its context window in the prompt
+ in_context = is_gen_q & is_prompt_kv & (kv_idx < context_end)
+
+ # Rule 2: Self-attention
+ is_self = q_idx == kv_idx
+
+ # Rule 3: Attend to earlier tokens in the SAME block (same block_idx, earlier gen_step)
+ is_gen_kv = ~is_prompt_kv
+ kv_gen_offset = kv_idx - prompt_length
+ kv_gen_step = kv_gen_offset // num_blocks
+ kv_block_idx = kv_gen_offset % num_blocks
+ same_block_prev = (
+ is_gen_q & is_gen_kv & (kv_block_idx == block_idx) & (kv_gen_step < gen_step)
+ )
+
+ return prompt_causal | in_context | is_self | same_block_prev
+
+
+def create_strided_block_mask(
+ prompt_length: int,
+ context_length: int,
+ max_generation_length: int,
+ stride: int,
+ num_blocks: int,
+ full_sequence_length: int,
+ batch_size: int,
+ num_heads: int,
+ device: torch.device,
+):
+ """
+ Create a BlockMask for flex_attention using the strided EBFT pattern.
+
+ Returns a BlockMask that can be passed directly to model.forward()
+ when using attn_implementation="flex_attention".
+
+ All mask parameters (prompt_length, context_length, stride, num_blocks, ...)
+ are captured as 0-d tensors so torch.compile/dynamo treats them as dynamic
+ values rather than guarding on literal ints. num_heads is unused (H=None).
+ """
+ # Wrap ALL mask params as 0-d tensors to prevent dynamo from guarding
+ # on their int values. Without this, each new anchor_offset or num_blocks
+ # triggers a recompile until the limit is hit → unfused fallback → OOM.
+ _prompt_length = torch.tensor(prompt_length, device=device)
+ _context_length = torch.tensor(context_length, device=device)
+ _max_gen_len = torch.tensor(max_generation_length, device=device)
+ _stride = torch.tensor(stride, device=device)
+ _num_blocks = torch.tensor(num_blocks, device=device)
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ return _strided_mask_mod(
+ b,
+ h,
+ q_idx,
+ kv_idx,
+ prompt_length=_prompt_length,
+ context_length=_context_length,
+ max_generation_length=_max_gen_len,
+ stride=_stride,
+ num_blocks=_num_blocks,
+ )
+
+ block_mask = create_block_mask(
+ mask_mod,
+ B=batch_size,
+ H=None, # broadcast across heads
+ Q_LEN=full_sequence_length,
+ KV_LEN=full_sequence_length,
+ device=device,
+ )
+ return block_mask
+
+
+def build_strided_position_ids(
+ full_sequence_length: int,
+ prompt_length: int,
+ context_length: int,
+ generation_step: int,
+ stride: int,
+ num_blocks: int,
+ device: torch.device,
+ batch_size: int = 1,
+):
+ """Build position IDs for strided generation (shared between flex and eager modes)."""
+ position_ids = torch.empty(
+ (batch_size, full_sequence_length), dtype=torch.long, device=device
+ )
+ position_ids[:, :prompt_length] = torch.arange(prompt_length, device=device)
+
+ block_starting_positions = (
+ torch.arange(num_blocks, device=device) * stride + context_length
+ )
+ for gen_step in range(generation_step):
+ start = prompt_length + gen_step * num_blocks
+ end = start + num_blocks
+ position_ids[:, start:end] = block_starting_positions + gen_step
+
+ return position_ids
+
+
+def build_strided_dense_mask_and_positions(
+ full_sequence_length: int,
+ prompt_length: int,
+ context_length: int,
+ generation_step: int,
+ max_generation_length: int,
+ stride: int,
+ num_blocks: int,
+ device: torch.device,
+ batch_size: int = 1,
+ dtype: torch.dtype = torch.bfloat16,
+):
+ """Build dense 4D attention mask (eager fallback) + position IDs."""
+ min_value = torch.finfo(dtype).min
+ attention_mask = torch.full(
+ (batch_size, 1, full_sequence_length, full_sequence_length),
+ min_value,
+ dtype=dtype,
+ device=device,
+ )
+
+ causal_mask = torch.tril(
+ torch.ones((prompt_length, prompt_length), dtype=torch.bool, device=device)
+ )
+ attention_mask[:, :, :prompt_length, :prompt_length].masked_fill_(
+ causal_mask.view(1, 1, prompt_length, prompt_length), 0.0
+ )
+
+ for gen_step in range(generation_step):
+ for block_idx in range(num_blocks):
+ gen_pos = prompt_length + gen_step * num_blocks + block_idx
+ context_end = min(
+ block_idx * stride + context_length,
+ prompt_length - max_generation_length,
+ )
+ attention_mask[:, 0, gen_pos, :context_end] = 0.0
+ attention_mask[:, 0, gen_pos, gen_pos] = 0.0
+ if gen_step > 0:
+ for prev_s in range(gen_step):
+ prev_pos = prompt_length + prev_s * num_blocks + block_idx
+ attention_mask[:, 0, gen_pos, prev_pos] = 0.0
+
+ position_ids = build_strided_position_ids(
+ full_sequence_length,
+ prompt_length,
+ context_length,
+ generation_step,
+ stride,
+ num_blocks,
+ device,
+ batch_size,
+ )
+ return attention_mask, position_ids
+
+
+# ---------------------------------------------------------------------------
+# Trainer
+# ---------------------------------------------------------------------------
+
+
+class AxolotlStridedEBFTTrainer(
+ RngLoaderMixin,
+ SchedulerMixin,
+ OptimizerMixin,
+ OptimizerInitMixin,
+ DistributedParallelMixin,
+ Trainer,
+):
+ """
+ Strided block-parallel EBFT trainer for unstructured text data.
+
+ Takes full text documents (no prompt/completion split needed), generates
+ short rollouts at multiple anchor points via strided attention, and trains
+ with feature-matching rewards.
+
+ When flex_attention is available (torch >= 2.5), uses compiled block masks
+ for efficient fused attention kernels. Otherwise falls back to eager
+ attention with dense 4D masks.
+ """
+
+ _tag_names = ["ebft", "strided", "axolotl"]
+
+ def __init__(self, model, args, train_dataset, **kwargs):
+ super().__init__(model=model, args=args, train_dataset=train_dataset, **kwargs)
+
+ # EBFT config
+ self.ebft_stride = getattr(args, "ebft_stride", 8)
+ self.ebft_context_length = getattr(args, "ebft_context_length", 8)
+ self.ebft_generate_max_len = getattr(args, "ebft_generate_max_len", 8)
+ self.ebft_n_samples = getattr(args, "ebft_n_samples_per_prompt", 4)
+ self.ebft_temperature = getattr(args, "ebft_temperature", 0.6)
+ self.ebft_top_p = getattr(args, "ebft_top_p", 1.0)
+ self.ebft_alignment_coef = getattr(args, "ebft_alignment_coef", 1.0)
+ self.ebft_diversity_coef = getattr(args, "ebft_diversity_coef", 1.0)
+ self.ebft_rl_coef = getattr(args, "ebft_rl_coef", 1.0)
+ self.ebft_ce_coef = getattr(args, "ebft_ce_coef", 0.0)
+ self.ebft_use_whitening = getattr(args, "ebft_use_whitening", False)
+ self.ebft_advantage_estimator = getattr(
+ args, "ebft_advantage_estimator", "rloo"
+ )
+ self.ebft_min_completion_prefix = getattr(args, "ebft_min_completion_prefix", 0)
+
+ # Validate config combinations
+ if self.ebft_use_whitening and self.ebft_diversity_coef > 0:
+ LOG.info(
+ "ebft: whitening + diversity enabled. Per paper Variant (i) (eq 49): "
+ "alignment uses cosine similarity (normalized), diversity uses raw dot product. "
+ "Both are bounded after whitening."
+ )
+ if self.ebft_n_samples == 1 and self.ebft_diversity_coef > 0:
+ LOG.warning(
+ "ebft.n_samples_per_prompt=1 with diversity_coef > 0: diversity penalty requires "
+ "multiple samples. Setting diversity_coef to 0."
+ )
+ self.ebft_diversity_coef = 0.0
+ if self.ebft_n_samples == 1 and self.ebft_advantage_estimator == "rloo":
+ LOG.warning(
+ "ebft.n_samples_per_prompt=1 with advantage_estimator='rloo': RLOO requires "
+ "multiple samples for baseline. Falling back to 'reinforce'."
+ )
+ self.ebft_advantage_estimator = "reinforce"
+
+ # Feature network config
+ feature_layers_frac = getattr(args, "ebft_feature_layers", [0.25, 0.5, 0.75])
+ embed_method = getattr(args, "ebft_embed_method", "last_token")
+ self.ebft_embed_method = embed_method
+
+ # Attention implementation selection
+ unwrapped = self.accelerator.unwrap_model(self.model)
+ self.use_flex_attention = (
+ _FLEX_ATTENTION_AVAILABLE and torch.cuda.is_available()
+ )
+
+ if self.use_flex_attention:
+ _patch_flex_attention_dtype()
+ LOG.info("Using flex_attention for strided EBFT (compiled block masks)")
+ if hasattr(unwrapped.config, "_attn_implementation"):
+ unwrapped.config._attn_implementation = "flex_attention"
+ self._num_heads = unwrapped.config.num_attention_heads
+ else:
+ LOG.info("Using eager attention for strided EBFT (dense 4D masks)")
+ if hasattr(unwrapped.config, "_attn_implementation"):
+ unwrapped.config._attn_implementation = "eager"
+ self._num_heads = None
+
+ # Feature network setup: either share weights with actor (PEFT models)
+ # or deepcopy (full-parameter models / multi-GPU).
+ first_param = next(unwrapped.parameters())
+ original_device = first_param.device
+ actor_gpu = (
+ original_device.index
+ if (original_device.type == "cuda" and original_device.index is not None)
+ else 0
+ )
+ visible_gpus = torch.cuda.device_count()
+
+ # Check if we can share weights (PEFT model on single GPU)
+ from peft import PeftModel
+
+ self._share_feature_weights = (
+ isinstance(unwrapped, PeftModel)
+ and visible_gpus == 1
+ and original_device.type != "meta"
+ )
+
+ if self._share_feature_weights:
+ # Share weights: use actor's base model with adapters disabled for
+ # feature extraction. Saves ~2.5 GB (no deepcopy of base weights).
+ self.feature_network = None # no separate network
+ self._feature_device = torch.device(f"cuda:{actor_gpu}")
+ self._feature_use_flex = self.use_flex_attention
+ LOG.info(
+ "Feature network shares actor weights (PEFT disable_adapter). "
+ f"Saving {sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9:.1f} GB"
+ )
+ elif visible_gpus > 1 and original_device.type != "meta":
+ # Multi-GPU: deepcopy to a separate device
+ self.feature_network = copy.deepcopy(unwrapped)
+ self.feature_network.to(dtype=torch.bfloat16)
+ self._feature_device = torch.device(
+ f"cuda:{(actor_gpu + 1) % visible_gpus}"
+ )
+ LOG.info(f"Creating frozen feature network on {self._feature_device}...")
+ self.feature_network.to(device=self._feature_device)
+ if _FLEX_ATTENTION_AVAILABLE and self._feature_device.type == "cuda":
+ if hasattr(self.feature_network.config, "_attn_implementation"):
+ self.feature_network.config._attn_implementation = "flex_attention"
+ self._feature_use_flex = True
+ LOG.info("Feature network using flex_attention")
+ else:
+ if hasattr(self.feature_network.config, "_attn_implementation"):
+ self.feature_network.config._attn_implementation = "eager"
+ self._feature_use_flex = False
+ for param in self.feature_network.parameters():
+ param.requires_grad = False
+ self.feature_network.eval()
+ elif original_device.type == "meta":
+ # FSDP2 with cpu_ram_efficient_loading
+ from transformers import AutoModelForCausalLM
+
+ feature_model_name = (
+ getattr(args, "model_name_or_path", None)
+ or unwrapped.config._name_or_path
+ )
+ self.feature_network = AutoModelForCausalLM.from_pretrained(
+ feature_model_name,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="eager",
+ )
+ self._feature_device = torch.device(f"cuda:{actor_gpu}")
+ self.feature_network.to(device=self._feature_device)
+ self._feature_use_flex = False
+ for param in self.feature_network.parameters():
+ param.requires_grad = False
+ self.feature_network.eval()
+ LOG.warning("Feature network loaded from pretrained (meta device)")
+ else:
+ # Single-GPU, non-PEFT: deepcopy on same device
+ self.feature_network = copy.deepcopy(unwrapped)
+ self.feature_network.to(dtype=torch.bfloat16)
+ self._feature_device = torch.device(f"cuda:{actor_gpu}")
+ self.feature_network.to(device=self._feature_device)
+ if _FLEX_ATTENTION_AVAILABLE:
+ if hasattr(self.feature_network.config, "_attn_implementation"):
+ self.feature_network.config._attn_implementation = "flex_attention"
+ self._feature_use_flex = True
+ else:
+ if hasattr(self.feature_network.config, "_attn_implementation"):
+ self.feature_network.config._attn_implementation = "eager"
+ self._feature_use_flex = False
+ for param in self.feature_network.parameters():
+ param.requires_grad = False
+ self.feature_network.eval()
+ LOG.info(
+ f"Created frozen feature network (deepcopy) on {self._feature_device}"
+ )
+
+ num_layers = unwrapped.config.num_hidden_layers
+ self.feature_layer_indices = [
+ int(frac * num_layers) for frac in feature_layers_frac
+ ]
+ LOG.info(
+ f"Strided EBFT: layers={self.feature_layer_indices}, "
+ f"stride={self.ebft_stride}, ctx={self.ebft_context_length}, "
+ f"gen_len={self.ebft_generate_max_len}, n_samples={self.ebft_n_samples}, "
+ f"embed={embed_method}, flex_attn={self.use_flex_attention}, "
+ f"min_completion_prefix={self.ebft_min_completion_prefix}"
+ )
+
+ def _build_strided_mask(
+ self,
+ full_seq_len,
+ seq_len,
+ generation_step,
+ num_blocks,
+ batch_size,
+ device,
+ dtype,
+ anchor_offset=None,
+ ):
+ """Build strided attention mask + position IDs using flex or eager.
+
+ Args:
+ anchor_offset: Position where anchors start. For unstructured data this
+ equals context_length; for structured data it equals
+ max(prompt_length + min_completion_prefix, context_length).
+ Defaults to self.ebft_context_length if not provided.
+ """
+ if anchor_offset is None:
+ anchor_offset = self.ebft_context_length
+
+ pos_ids = build_strided_position_ids(
+ full_seq_len,
+ seq_len,
+ anchor_offset,
+ generation_step,
+ self.ebft_stride,
+ num_blocks,
+ device,
+ batch_size,
+ )
+
+ if self.use_flex_attention:
+ block_mask = create_strided_block_mask(
+ prompt_length=seq_len,
+ context_length=anchor_offset,
+ max_generation_length=self.ebft_generate_max_len,
+ stride=self.ebft_stride,
+ num_blocks=num_blocks,
+ full_sequence_length=full_seq_len,
+ batch_size=batch_size,
+ num_heads=self._num_heads,
+ device=device,
+ )
+ return block_mask, pos_ids
+
+ dense_mask, pos_ids = build_strided_dense_mask_and_positions(
+ full_sequence_length=full_seq_len,
+ prompt_length=seq_len,
+ context_length=anchor_offset,
+ generation_step=generation_step,
+ max_generation_length=self.ebft_generate_max_len,
+ stride=self.ebft_stride,
+ num_blocks=num_blocks,
+ device=device,
+ batch_size=batch_size,
+ dtype=dtype,
+ )
+ return dense_mask, pos_ids
+
+ def compute_loss(
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
+ ):
+ """
+ Full strided EBFT training step.
+
+ 1. Take tokenized documents from inputs
+ 2. Generate n_samples short rollouts at strided anchor points
+ 3. Extract features from frozen network for both generated and GT blocks
+ 4. Compute alignment/diversity rewards per block
+ 5. Compute RLOO advantages
+ 6. Policy gradient loss on the strided forward pass
+
+ Supports both unstructured text (no prompt/completion split) and
+ structured data (prompt + completion with labels masking). For structured
+ data, anchors are placed only within the completion span.
+ """
+ outputs = None
+ device = next(model.parameters()).device
+ input_ids = inputs["input_ids"].to(device) # (B, seq_len)
+ B, seq_len = input_ids.shape
+
+ stride = self.ebft_stride
+ ctx_len = self.ebft_context_length
+ gen_len = self.ebft_generate_max_len
+ n_samples = self.ebft_n_samples
+
+ # --- Detect structured data and compute anchor_offset ---
+ # For structured data, anchors must start within the completion span.
+ # anchor_offset replaces ctx_len as the starting position for anchors.
+ is_structured = False
+ if "prompt_length" in inputs:
+ # Explicit prompt_length from dataset transform
+ prompt_lengths = inputs["prompt_length"].to(device) # (B,)
+ is_structured = True
+ elif "labels" in inputs:
+ # Derive prompt_length from labels: first position where labels != -100
+ labels = inputs["labels"].to(device)
+ non_masked = labels != -100
+ # prompt_length = index of first non-masked token (or seq_len if all masked)
+ has_completion = non_masked.any(dim=1)
+ prompt_lengths = torch.where(
+ has_completion,
+ non_masked.float().argmax(dim=1),
+ torch.tensor(seq_len, device=device, dtype=torch.long),
+ ).long()
+ is_structured = prompt_lengths.min().item() > 0
+
+ if is_structured:
+ # Use max prompt_length across batch for uniform anchor_offset
+ max_prompt_len = prompt_lengths.max().item()
+ anchor_offset = max(
+ max_prompt_len + self.ebft_min_completion_prefix, ctx_len
+ )
+ else:
+ anchor_offset = ctx_len
+
+ num_blocks = (seq_len - gen_len - anchor_offset) // stride + 1
+ if num_blocks <= 0:
+ LOG.warning(
+ f"Sequence too short for strided EBFT: seq_len={seq_len}, "
+ f"anchor_offset={anchor_offset}, "
+ f"need >= {gen_len + anchor_offset}. Returning zero loss."
+ )
+ dummy_loss = input_ids.float().mean() * 0.0
+ return (dummy_loss, None) if return_outputs else dummy_loss
+
+ # --- Step 1: Generate strided blocks for n_samples ---
+ repeated_ids = input_ids.repeat_interleave(n_samples, dim=0)
+
+ with torch.no_grad():
+ full_sequences = self._generate_strided_blocks(
+ model,
+ repeated_ids,
+ num_blocks,
+ anchor_offset=anchor_offset,
+ )
+
+ # --- Step 2: Build strided mask for full generation ---
+ full_seq_len = full_sequences.shape[1]
+ model_dtype = next(model.parameters()).dtype
+
+ # Free generation-phase memory before training forward pass
+ torch.cuda.empty_cache()
+
+ attn_mask, pos_ids = self._build_strided_mask(
+ full_seq_len,
+ seq_len,
+ gen_len,
+ num_blocks,
+ B * n_samples,
+ device,
+ model_dtype,
+ anchor_offset=anchor_offset,
+ )
+
+ # --- Step 3: Forward pass through actor for log probs ---
+ # Memory optimization: process one sample at a time through the backbone
+ # to avoid B*N × S × H activation memory. For Llama-1B at S=3900, each
+ # sample's backbone forward takes ~8.7 GB with grad checkpointing.
+ # Processing B*N=4 at once would need ~35 GB → OOM.
+ # Instead, we accumulate per-token logprobs sample-by-sample.
+ gen_start = seq_len - 1 # shifted index where generated tokens start
+ compute_start = 0 if self.ebft_ce_coef > 0 else gen_start
+ BN = B * n_samples
+
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ # Navigate through PEFT wrapper to get backbone + lm_head
+ base_model = getattr(unwrapped_model, "model", unwrapped_model)
+ if hasattr(base_model, "model") and hasattr(base_model, "lm_head"):
+ backbone = base_model.model
+ lm_head = base_model.lm_head
+ else:
+ backbone = None
+
+ per_token_logps_list = []
+ shift_labels = full_sequences[:, 1:]
+
+ if backbone is not None:
+ # Process one sample at a time: backbone → chunked lm_head → logprobs
+ # This keeps peak memory at ~1 sample's activations instead of B*N.
+ for s_idx in range(BN):
+ seq_s = full_sequences[s_idx : s_idx + 1] # (1, full_seq_len)
+ # Handle attention mask format (BlockMask vs dense 4D)
+ if isinstance(attn_mask, torch.Tensor) and attn_mask.dim() == 4:
+ mask_s = attn_mask[s_idx : s_idx + 1]
+ else:
+ mask_s = attn_mask # BlockMask broadcasts over batch
+ pos_s = pos_ids[s_idx : s_idx + 1]
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ backbone_out = backbone(
+ seq_s,
+ attention_mask=mask_s,
+ position_ids=pos_s,
+ return_dict=True,
+ )
+ hidden_s = backbone_out.last_hidden_state # (1, full_seq_len, H)
+ labels_s = shift_labels[s_idx : s_idx + 1]
+
+ logps_s = torch.zeros(
+ 1,
+ hidden_s.shape[1] - 1,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ region_h = hidden_s[:, compute_start:-1, :]
+ region_l = labels_s[:, compute_start:]
+ chunk_size = 256
+ for i in range(0, region_h.shape[1], chunk_size):
+ h_chunk = region_h[:, i : i + chunk_size, :]
+ l_chunk = region_l[:, i : i + chunk_size]
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits_chunk = lm_head(h_chunk)
+ chunk_lp = F.log_softmax(logits_chunk.float(), dim=-1)
+ logps_s[
+ :, compute_start + i : compute_start + i + h_chunk.shape[1]
+ ] = chunk_lp.gather(-1, l_chunk.unsqueeze(-1)).squeeze(-1)
+ del logits_chunk, chunk_lp
+ per_token_logps_list.append(logps_s)
+ del hidden_s, backbone_out, region_h
+
+ per_token_logps = torch.cat(per_token_logps_list, dim=0)
+ else:
+ # Fallback: full forward (non-standard model architecture)
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ outputs = model(
+ full_sequences,
+ attention_mask=attn_mask,
+ position_ids=pos_ids,
+ return_dict=True,
+ )
+ logits = outputs.logits
+ per_token_logps = torch.zeros(
+ logits.shape[0],
+ logits.shape[1] - 1,
+ device=device,
+ dtype=torch.float32,
+ )
+ region_logits = logits[:, compute_start:-1, :]
+ region_labels = shift_labels[:, compute_start:]
+ chunk_size = 256
+ for i in range(0, region_logits.shape[1], chunk_size):
+ chunk_logits = region_logits[:, i : i + chunk_size, :]
+ chunk_labels = region_labels[:, i : i + chunk_size]
+ chunk_lp = F.log_softmax(chunk_logits.float(), dim=-1)
+ per_token_logps[
+ :, compute_start + i : compute_start + i + chunk_logits.shape[1]
+ ] = chunk_lp.gather(-1, chunk_labels.unsqueeze(-1)).squeeze(-1)
+ del logits, region_logits
+
+ action_mask = torch.zeros(
+ per_token_logps.shape, dtype=torch.bool, device=device
+ )
+ # Only mark actual generated tokens (not padding beyond num_blocks * gen_len)
+ gen_end = gen_start + num_blocks * gen_len
+ action_mask[:, gen_start:gen_end] = True
+
+ # --- Step 4: Extract features and compute rewards ---
+ with torch.no_grad():
+ block_rewards = self._compute_block_rewards(
+ full_sequences,
+ attn_mask,
+ pos_ids,
+ input_ids,
+ num_blocks,
+ B,
+ n_samples,
+ anchor_offset=anchor_offset,
+ )
+
+ del attn_mask, pos_ids
+ torch.cuda.empty_cache()
+
+ # --- Step 5: Compute advantages ---
+ advantages_per_block = self._compute_advantages(
+ block_rewards, B, n_samples, num_blocks
+ )
+
+ token_advantages = advantages_per_block.repeat_interleave(gen_len, dim=1)
+ full_advantages = torch.zeros_like(per_token_logps)
+ # Only fill actual generated region (not padding beyond num_blocks * gen_len)
+ adv_len = token_advantages.shape[1] # = num_blocks * gen_len
+ full_advantages[:, gen_start : gen_start + adv_len] = token_advantages
+
+ # --- Step 6: Compute loss ---
+ # RL loss: REINFORCE on generated tokens (needs grad through per_token_logps)
+ rl_loss_per_token = -per_token_logps * full_advantages.detach()
+ rl_loss = (
+ rl_loss_per_token * action_mask.float()
+ ).sum() / action_mask.float().sum().clamp(min=1)
+
+ # CE loss: For structured data, only compute on completion ground-truth tokens
+ # (labels != -100 in the original input). For unstructured data, compute on
+ # all non-action (prompt) tokens as before.
+ ce_loss = torch.tensor(0.0, device=device)
+ if self.ebft_ce_coef > 0:
+ if is_structured and "labels" in inputs:
+ labels = inputs["labels"].to(device) # (B, seq_len)
+ shifted_labels = labels[:, 1:] # (B, seq_len - 1)
+ ce_mask_base = shifted_labels != -100 # (B, seq_len - 1)
+ ce_mask_repeated = ce_mask_base.repeat_interleave(n_samples, dim=0)
+ ce_mask = torch.zeros(
+ per_token_logps.shape, dtype=torch.bool, device=device
+ )
+ ce_mask[:, : ce_mask_repeated.shape[1]] = ce_mask_repeated
+ ce_mask[:, gen_start:] = False
+ else:
+ ce_mask = ~action_mask
+ ce_loss = (
+ -per_token_logps * ce_mask.float()
+ ).sum() / ce_mask.float().sum().clamp(min=1)
+
+ loss = self.ebft_rl_coef * rl_loss + self.ebft_ce_coef * ce_loss
+
+ # --- Log metrics ---
+ if self.state.global_step % self.args.logging_steps == 0:
+ _alignment = getattr(self, "_last_alignment", 0.0)
+ _diversity = getattr(self, "_last_diversity", 0.0)
+ _cfm = getattr(self, "_last_cfm", 0.0)
+ _mean_reward = block_rewards.mean().item()
+ _adv_std = advantages_per_block.std().item()
+
+ log_dict = {
+ "ebft/rl_loss": rl_loss.item(),
+ "ebft/ce_loss": ce_loss.item(),
+ "ebft/cfm_loss": _cfm,
+ "ebft/mean_reward": _mean_reward,
+ "ebft/alignment": _alignment,
+ "ebft/diversity": _diversity,
+ "ebft/num_blocks": num_blocks,
+ "ebft/advantages_std": _adv_std,
+ }
+ if is_structured:
+ log_dict["ebft/anchor_offset"] = anchor_offset
+ self.log(log_dict)
+
+ # Human-readable summary with direction arrows:
+ # alignment (^ better): 2 * cosine sim to GT features, range [-2, 2]
+ # diversity (v better): 2 * mean pairwise sim among samples, lower = more diverse
+ # cfm_loss (v better): ||E[phi(y_hat)] - phi(y)||^2
+ # reward (^ better): alignment_coef * alignment - diversity_coef * diversity
+ LOG.info(
+ f"step {self.state.global_step} | "
+ f"align {_alignment:+.3f} ^ | "
+ f"divers {_diversity:+.3f} v | "
+ f"cfm {_cfm:.3f} v | "
+ f"reward {_mean_reward:+.3f} ^ | "
+ f"adv_std {_adv_std:.3f} | "
+ f"blocks {num_blocks}"
+ )
+
+ return (loss, outputs) if return_outputs else loss
+
+ @torch._dynamo.disable
+ @torch.no_grad()
+ def _generate_strided_blocks(
+ self, model, prompt_ids, num_blocks, anchor_offset=None
+ ):
+ """Generate tokens using strided block-parallel attention.
+
+ Uses eager attention (dense 4D masks) during generation to avoid dynamo
+ recompilation — each generation step has a different sequence length.
+ The training forward pass (fixed size) uses flex_attention when available.
+
+ Args:
+ anchor_offset: Position where anchors start. Defaults to context_length.
+ """
+ B, seq_len = prompt_ids.shape
+ gen_len = self.ebft_generate_max_len
+ stride = self.ebft_stride
+ if anchor_offset is None:
+ anchor_offset = self.ebft_context_length
+ temperature = self.ebft_temperature
+ top_p = self.ebft_top_p
+ device = prompt_ids.device
+ model_dtype = next(model.parameters()).dtype
+
+ full_sequence = prompt_ids.clone()
+
+ # Force eager attention during generation to avoid dynamo recompiles from:
+ # 1. Variable sequence lengths per gen step → size-mismatch recompiles
+ # 2. no_grad vs grad toggling → grad_mode recompiles
+ # Both cause dynamo to hit the recompile limit → unfused fallback → OOM
+ unwrapped = self.accelerator.unwrap_model(model)
+ with override_attn_implementation(unwrapped, "eager"):
+ for generation_step in range(gen_len):
+ cur_len = full_sequence.shape[1]
+
+ dense_mask, pos_ids = build_strided_dense_mask_and_positions(
+ full_sequence_length=cur_len,
+ prompt_length=seq_len,
+ context_length=anchor_offset,
+ generation_step=generation_step,
+ max_generation_length=gen_len,
+ stride=stride,
+ num_blocks=num_blocks,
+ device=device,
+ batch_size=B,
+ dtype=model_dtype,
+ )
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ output = model(
+ full_sequence,
+ attention_mask=dense_mask,
+ position_ids=pos_ids,
+ return_dict=True,
+ )
+ all_logits = output.logits
+
+ logit_positions = []
+ for block_idx in range(num_blocks):
+ if generation_step == 0:
+ # Last token of the context window predicts the first rollout token
+ pos = anchor_offset + block_idx * stride - 1
+ else:
+ pos = seq_len + (generation_step - 1) * num_blocks + block_idx
+ logit_positions.append(pos)
+
+ position_indices = torch.tensor(logit_positions, device=device)
+ block_logits = all_logits.index_select(1, position_indices)
+
+ if temperature > 0:
+ block_logits = block_logits / temperature
+ probs = torch.softmax(block_logits, dim=-1)
+
+ if top_p < 1.0:
+ sorted_probs, sorted_idx = torch.sort(
+ probs, descending=True, dim=-1
+ )
+ cumulative = torch.cumsum(sorted_probs, dim=-1)
+ remove = cumulative > top_p
+ remove[..., 1:] = remove[..., :-1].clone()
+ remove[..., 0] = False
+ mask = torch.zeros_like(probs, dtype=torch.bool)
+ mask.scatter_(-1, sorted_idx, remove)
+ probs[mask] = 0
+ probs = probs / probs.sum(dim=-1, keepdim=True)
+
+ flat_probs = probs.view(-1, probs.shape[-1])
+ sampled = torch.multinomial(flat_probs, 1).squeeze(-1)
+ sampled = sampled.view(B, num_blocks)
+ else:
+ sampled = torch.argmax(block_logits, dim=-1)
+
+ full_sequence = torch.cat([full_sequence, sampled], dim=1)
+
+ return full_sequence
+
+ @torch._dynamo.disable
+ @torch.no_grad()
+ def _compute_block_rewards(
+ self,
+ full_sequences,
+ attn_mask,
+ pos_ids,
+ original_ids,
+ num_blocks,
+ batch_size,
+ n_samples,
+ anchor_offset=None,
+ ):
+ """Extract features and compute per-block rewards. Returns (B, N, NB).
+
+ Args:
+ anchor_offset: Position where anchors start. For structured data this
+ is after the prompt; for unstructured it equals context_length.
+ """
+ device = full_sequences.device
+ seq_len = original_ids.shape[1]
+ gen_len = self.ebft_generate_max_len
+ stride = self.ebft_stride
+ if anchor_offset is None:
+ anchor_offset = self.ebft_context_length
+
+ # Run feature network on its device WITH the strided attention mask.
+ # Without the strided mask, generated tokens see tokens from other blocks
+ # via default causal attention, corrupting the feature representations.
+ fd = self._feature_device
+ fn_seqs = full_sequences.to(fd)
+ fn_pos = pos_ids.to(fd)
+
+ # Determine which model to use for feature extraction
+ if self._share_feature_weights:
+ # Use actor's base weights with adapters disabled.
+ # Avoid the shared compiled flex_attention kernel here: feature
+ # extraction runs under no_grad while the training forward runs with
+ # grad, and each grad_mode switch triggers a recompile.
+ unwrapped_actor = self.accelerator.unwrap_model(self.model)
+ feat_model = unwrapped_actor
+ feature_ctx = unwrapped_actor.disable_adapter()
+ # Use SDPA instead of flex for feature extraction. SDPA is fused (no
+ # score matrix materialization) and needs no compilation, which suits
+ # a no_grad forward pass.
+ attn_ctx = override_attn_implementation(unwrapped_actor, "sdpa")
+ use_flex_for_features = False
+ else:
+ feat_model = self.feature_network
+ feature_ctx = contextlib.nullcontext()
+ attn_ctx = contextlib.nullcontext()
+ use_flex_for_features = self._feature_use_flex
+
+ # Build strided mask — flex block mask if available, else dense 4D
+ if use_flex_for_features:
+ fn_attn_mask = create_strided_block_mask(
+ prompt_length=seq_len,
+ context_length=anchor_offset,
+ max_generation_length=gen_len,
+ stride=stride,
+ num_blocks=num_blocks,
+ full_sequence_length=full_sequences.shape[1],
+ batch_size=full_sequences.shape[0],
+ num_heads=feat_model.config.num_attention_heads,
+ device=fd,
+ )
+ else:
+ fn_attn_mask, _ = build_strided_dense_mask_and_positions(
+ full_sequence_length=full_sequences.shape[1],
+ prompt_length=seq_len,
+ context_length=anchor_offset,
+ generation_step=gen_len,
+ max_generation_length=gen_len,
+ stride=stride,
+ num_blocks=num_blocks,
+ device=fd,
+ batch_size=full_sequences.shape[0],
+ dtype=torch.bfloat16,
+ )
+
+ with (
+ feature_ctx,
+ attn_ctx,
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16),
+ ):
+ was_training = feat_model.training
+ feat_model.eval()
+ fn_outputs = feat_model(
+ fn_seqs,
+ attention_mask=fn_attn_mask,
+ position_ids=fn_pos,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ if was_training:
+ feat_model.train()
+ layer_hidden_states = [
+ fn_outputs.hidden_states[idx].to(device)
+ for idx in self.feature_layer_indices
+ ]
+ del fn_outputs, fn_seqs, fn_pos, fn_attn_mask
+
+ # Normalize each layer's hidden states separately (like the reference critic),
+ # then concatenate. This prevents one dominant layer from suppressing others.
+ normalized_layers = [F.normalize(h, p=2, dim=-1) for h in layer_hidden_states]
+ features = torch.cat(normalized_layers, dim=-1).to(device)
+ del layer_hidden_states, normalized_layers
+
+ # Ground-truth features start from anchor_offset (not ctx_len) so they
+ # align with where anchors are actually placed.
+ gt_features = features[:, anchor_offset:seq_len, :]
+ # Only take actual generated tokens (exclude padding beyond num_blocks * gen_len)
+ gen_features = features[:, seq_len : seq_len + num_blocks * gen_len, :]
+
+ gt_block_features = gt_features.unfold(1, gen_len, stride).permute(0, 1, 3, 2)
+ gen_block_features = gen_features.reshape(
+ batch_size * n_samples, gen_len, num_blocks, -1
+ ).transpose(1, 2)
+
+ if self.ebft_embed_method == "mean_pooling":
+ gt_emb = gt_block_features.mean(dim=2)
+ gen_emb = gen_block_features.mean(dim=2)
+ else: # last_token
+ gt_emb = gt_block_features[:, :, -1, :]
+ gen_emb = gen_block_features[:, :, -1, :]
+
+ gt_emb = gt_emb.view(batch_size, n_samples, num_blocks, -1)
+ gen_emb = gen_emb.view(batch_size, n_samples, num_blocks, -1)
+
+ if self.ebft_use_whitening:
+ whitened_gen, whitened_gt = [], []
+ for b in range(batch_size):
+ for nb in range(num_blocks):
+ w_gen, w_gt = whiten_embeddings_batched(
+ gen_emb[b, :, nb, :],
+ gt_emb[b, :, nb, :],
+ )
+ whitened_gen.append(w_gen)
+ whitened_gt.append(w_gt)
+ gen_emb = (
+ torch.stack(whitened_gen)
+ .view(batch_size, num_blocks, n_samples, -1)
+ .transpose(1, 2)
+ )
+ gt_emb = (
+ torch.stack(whitened_gt)
+ .view(batch_size, num_blocks, n_samples, -1)
+ .transpose(1, 2)
+ )
+
+ alignment = F.cosine_similarity(gen_emb, gt_emb, dim=-1)
+
+ # Batched diversity: reshape to avoid per-block Python loop
+ diversity = torch.zeros_like(alignment)
+ if n_samples > 1:
+ # (B, N, NB, D) → (B*NB, N, D) for a single batched bmm
+ gen_for_div = gen_emb.permute(0, 2, 1, 3).reshape(
+ batch_size * num_blocks, n_samples, -1
+ )
+ sims = torch.bmm(gen_for_div, gen_for_div.transpose(1, 2)) # (B*NB, N, N)
+ eye = torch.eye(n_samples, device=device, dtype=torch.bool)
+ sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
+ div_flat = sims.sum(dim=-1) / (n_samples - 1) # (B*NB, N)
+ diversity = div_flat.view(batch_size, num_blocks, n_samples).permute(
+ 0, 2, 1
+ ) # (B, N, NB)
+
+ # Scale by 2 per paper equation (7):
+ # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
+ alignment = alignment * 2
+ diversity = diversity * 2
+
+ # Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
+ # Mean generated embedding per prompt, squared distance to GT
+ mean_gen_emb = gen_emb.mean(dim=1, keepdim=True) # (B, 1, NB, D)
+ gt_for_cfm = gt_emb[:, 0:1, :, :] # (B, 1, NB, D) — one GT per prompt
+ cfm_loss = ((mean_gen_emb - gt_for_cfm) ** 2).sum(dim=-1).mean()
+
+ # Store for logging
+ self._last_alignment = alignment.mean().item()
+ self._last_diversity = diversity.mean().item()
+ self._last_cfm = cfm_loss.item()
+
+ return (
+ self.ebft_alignment_coef * alignment - self.ebft_diversity_coef * diversity
+ )
+
+ def _compute_advantages(self, rewards, batch_size, n_samples, num_blocks):
+ """Compute advantages (RLOO, group-norm, or raw rewards). rewards: (B, N, NB) → (B*N, NB)."""
+ if self.ebft_advantage_estimator == "rloo" and n_samples > 1:
+ total = rewards.sum(dim=1, keepdim=True)
+ baseline = (total - rewards) / (n_samples - 1)
+ advantages = rewards - baseline
+ elif self.ebft_advantage_estimator == "group_norm" and n_samples > 1:
+ mean = rewards.mean(dim=1, keepdim=True)
+ std = rewards.std(dim=1, keepdim=True) + 1e-8
+ advantages = (rewards - mean) / std
+ else:
+ advantages = rewards
+ return advantages.view(batch_size * n_samples, num_blocks)
diff --git a/src/axolotl/core/trainers/ebft/trainer.py b/src/axolotl/core/trainers/ebft/trainer.py
new file mode 100644
index 000000000..1c27fa91f
--- /dev/null
+++ b/src/axolotl/core/trainers/ebft/trainer.py
@@ -0,0 +1,531 @@
+"""
+EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.
+
+Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
+the standard GRPO reward function interface.
+
+Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
+ (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
+"""
+
+import contextlib
+import copy
+from typing import TYPE_CHECKING, Any
+
+import torch
+from datasets import Dataset, IterableDataset
+from peft import PeftModel
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
+
+from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
+from axolotl.core.trainers.ebft.rewards import (
+ apply_embed_method,
+ extract_hidden_states,
+ get_alignment_rewards,
+ get_diversity_rewards,
+ whiten_embeddings_batched,
+)
+from axolotl.core.trainers.grpo.trainer import (
+ AxolotlAsyncGRPOTrainer,
+ AxolotlGRPOTrainer,
+)
+from axolotl.utils.logging import get_logger
+
+if TYPE_CHECKING:
+ from collections import defaultdict
+
+ from accelerate import Accelerator
+ from trl.generation.vllm_generation import VLLMGeneration
+
+LOG = get_logger(__name__)
+
+
+class EBFTMixin:
+ """
+ Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.
+
+ Provides:
+ - Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
+ - _feature_matching_reward() callable for GRPO reward function interface
+ - _sequential_rollout() for multi-turn conversations
+ """
+
+ # Type stubs for attributes provided by the composed GRPOTrainer base class.
+ # These are not defined here but accessed via cooperative multiple inheritance.
+ if TYPE_CHECKING:
+ accelerator: Accelerator
+ model: PreTrainedModel
+ args: AxolotlEBFTConfig
+ processing_class: PreTrainedTokenizerBase
+ num_generations: int
+ vllm_generation: VLLMGeneration
+ _metrics: defaultdict
+
+ _tag_names = ["trl", "ebft", "axolotl"]
+
+ def __init__(
+ self,
+ model: str | PreTrainedModel,
+ args: AxolotlEBFTConfig | None = None,
+ train_dataset: Dataset | IterableDataset | None = None,
+ eval_dataset: Dataset
+ | IterableDataset
+ | dict[str, Dataset | IterableDataset]
+ | None = None,
+ processing_class: PreTrainedTokenizerBase | None = None,
+ callbacks: list[TrainerCallback] | None = None,
+ optimizers: tuple[
+ torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
+ ] = (None, None),
+ peft_config: Any | None = None,
+ ):
+ # Pass our feature-matching reward function to GRPOTrainer
+ # It will be called with (prompts, completions, **kwargs) where
+ # kwargs includes all extra dataset fields like "ground_truth"
+ super().__init__( # type: ignore[call-arg]
+ model=model,
+ reward_funcs=[self._feature_matching_reward],
+ args=args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ peft_config=peft_config,
+ )
+ assert args is not None
+
+ # --- Feature network setup ---
+ unwrapped = self.accelerator.unwrap_model(self.model)
+ # Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping
+ self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
+ unwrapped, "disable_adapter"
+ )
+
+ if self._share_feature_weights:
+ # Share weights: use actor's base model with adapters disabled.
+ # Saves a full model copy (~8 GB for 4B model).
+ self.feature_network = None
+ param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
+ LOG.info(
+ f"EBFT feature network shares actor weights (PEFT disable_adapter). "
+ f"Saving ~{param_gb:.1f} GB"
+ )
+ else:
+ LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
+ self.feature_network = copy.deepcopy(unwrapped)
+ for param in self.feature_network.parameters():
+ param.requires_grad = False
+ self.feature_network.eval()
+
+ # Compute layer indices from fractional depths
+ # Handle VLM models where num_hidden_layers is on text_config
+ config = unwrapped.config
+ if hasattr(config, "text_config") and hasattr(
+ config.text_config, "num_hidden_layers"
+ ):
+ config = config.text_config
+ num_layers = config.num_hidden_layers
+ self.feature_layer_indices = [
+ int(frac * num_layers) for frac in args.ebft_feature_layers
+ ]
+ LOG.info(
+ f"EBFT feature extraction from layers {self.feature_layer_indices} "
+ f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
+ )
+ if args.ebft_adaptive_max_tokens:
+ LOG.info(
+ f"EBFT adaptive max_tokens enabled "
+ f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
+ )
+
+ _adaptive_max_lock = None # initialized lazily
+
+ def _generate_only(self, inputs, rank0_only=False):
+ """Override to set per-batch max_tokens based on ground-truth length.
+
+ Uses a lock to prevent race conditions in async mode, where concurrent
+ background threads could interleave mutations of max_completion_length.
+ """
+ import threading
+
+ args = self.args
+ if (
+ args.ebft_adaptive_max_tokens
+ and hasattr(self, "vllm_generation")
+ and inputs
+ ):
+ gt_texts = [
+ x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
+ ]
+ if gt_texts:
+ gt_token_counts = [
+ len(self.processing_class.encode(gt, add_special_tokens=False))
+ for gt in gt_texts
+ ]
+ multiplier = args.ebft_gt_length_multiplier
+ max_completion = self.vllm_generation.max_completion_length
+ adaptive_max = max(
+ min(int(c * multiplier), max_completion) for c in gt_token_counts
+ )
+ adaptive_max = min(max(adaptive_max, 64), max_completion)
+
+ if self._adaptive_max_lock is None:
+ self._adaptive_max_lock = threading.Lock()
+ with self._adaptive_max_lock:
+ original = self.vllm_generation.max_completion_length
+ self.vllm_generation.max_completion_length = adaptive_max
+ try:
+ return super()._generate_only(inputs, rank0_only)
+ finally:
+ self.vllm_generation.max_completion_length = original
+
+ return super()._generate_only(inputs, rank0_only)
+
+ @torch.no_grad()
+ def _feature_matching_reward(
+ self,
+ prompts: list,
+ completions: list,
+ ground_truth: list[str] | None = None,
+ remaining_turns: list | None = None,
+ **kwargs,
+ ) -> list[float]:
+ """
+ Compute feature-matching rewards for generated completions.
+
+ This is called by GRPOTrainer's _generate_and_score_completions()
+ as a standard reward function. The `ground_truth` field comes from
+ the dataset via reward_kwargs.
+
+ For multi-turn conversations, `remaining_turns` contains the subsequent
+ user/assistant turn pairs. When present, we do sequential rollouts:
+ generate each assistant turn conditioned on history + previous generations,
+ then compute feature-matching rewards on the full generated conversation.
+
+ Args:
+ prompts: List of prompt strings/messages
+ completions: List of generated completion strings
+ ground_truth: List of reference completion strings (from dataset)
+ remaining_turns: List of remaining conversation turns after the
+ first assistant turn (for multi-turn rollouts)
+
+ Returns:
+ List of scalar rewards, one per completion
+ """
+ if ground_truth is None:
+ LOG.warning("No ground_truth field in dataset — using zero rewards")
+ return [0.0] * len(prompts)
+
+ device = self.accelerator.device
+ args = self.args
+ num_gens = self.num_generations
+
+ # --- Multi-turn sequential rollout ---
+ # If remaining_turns is provided, generate subsequent assistant turns
+ # by calling vLLM for each turn, building up the full conversation.
+ if remaining_turns is not None and hasattr(self, "vllm_generation"):
+ completions = self._sequential_rollout(
+ prompts, completions, remaining_turns, num_gens
+ )
+
+ # --- Tokenize generated sequences: prompt + completion ---
+ gen_texts = []
+ gen_prompt_texts = []
+ for p, c in zip(prompts, completions, strict=True):
+ if isinstance(p, list):
+ prompt_text = self.processing_class.apply_chat_template(
+ p, tokenize=False, add_generation_prompt=True
+ )
+ else:
+ prompt_text = p
+ if isinstance(c, list):
+ comp_text = c[0].get("content", "") if c else ""
+ else:
+ comp_text = c
+ gen_texts.append(prompt_text + comp_text)
+ gen_prompt_texts.append(prompt_text)
+
+ gen_encoded = self.processing_class(
+ text=gen_texts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=getattr(self.args, "max_length", None)
+ or getattr(self.args, "max_seq_length", None)
+ or 2048,
+ add_special_tokens=False,
+ )
+ gen_ids = gen_encoded["input_ids"].to(device)
+ gen_mask = gen_encoded["attention_mask"].to(device)
+
+ # Compute prompt lengths for completion_mean pooling
+ gen_prompt_lengths = torch.tensor(
+ [
+ len(self.processing_class.encode(pt, add_special_tokens=False))
+ for pt in gen_prompt_texts
+ ],
+ device=device,
+ )
+
+ # --- Tokenize ground-truth sequences: prompt + ground_truth ---
+ # For multi-turn (remaining_turns present), render the full GT conversation
+ # through the chat template to preserve role markers between turns.
+ gt_texts = []
+ gt_prompt_texts = []
+ for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
+ if i % num_gens != 0:
+ continue # Only need one GT per prompt group
+ if isinstance(p, list):
+ prompt_text = self.processing_class.apply_chat_template(
+ p, tokenize=False, add_generation_prompt=True
+ )
+ # Multi-turn: build full GT conversation with remaining turns
+ if remaining_turns is not None:
+ prompt_idx = i // num_gens
+ turns = (
+ remaining_turns[prompt_idx]
+ if prompt_idx < len(remaining_turns)
+ else []
+ )
+ if turns:
+ gt_conv = list(p) + [{"role": "assistant", "content": gt}]
+ gt_conv.extend(turns)
+ full_gt_text = self.processing_class.apply_chat_template(
+ gt_conv, tokenize=False, add_generation_prompt=False
+ )
+ gt_texts.append(full_gt_text)
+ gt_prompt_texts.append(prompt_text)
+ continue
+ else:
+ prompt_text = p
+ gt_texts.append(prompt_text + gt)
+ gt_prompt_texts.append(prompt_text)
+
+ gt_encoded = self.processing_class(
+ text=gt_texts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=getattr(self.args, "max_length", None)
+ or getattr(self.args, "max_seq_length", None)
+ or 2048,
+ add_special_tokens=False,
+ )
+ gt_ids = gt_encoded["input_ids"].to(device)
+ gt_mask = gt_encoded["attention_mask"].to(device)
+
+ gt_prompt_lengths = torch.tensor(
+ [
+ len(self.processing_class.encode(pt, add_special_tokens=False))
+ for pt in gt_prompt_texts
+ ],
+ device=device,
+ )
+
+ # --- Extract features from frozen feature network ---
+ # INVARIANT: disable_adapter() yields the unmodified base weights because
+ # _sync_peft_weights_no_merge and _sync_lora_adapter never call
+ # merge_adapter() — they compute merged weights as new tensors or save
+ # the adapter to filesystem. Base weights are never modified in-place.
+ if self._share_feature_weights:
+ unwrapped = self.accelerator.unwrap_model(self.model)
+ feature_ctx = unwrapped.disable_adapter()
+ else:
+ unwrapped = self.feature_network
+ feature_ctx = contextlib.nullcontext()
+
+ with feature_ctx:
+ was_training = unwrapped.training
+ unwrapped.eval()
+ gen_hidden = extract_hidden_states(
+ unwrapped, gen_ids, gen_mask, self.feature_layer_indices
+ )
+ gt_hidden = extract_hidden_states(
+ unwrapped, gt_ids, gt_mask, self.feature_layer_indices
+ )
+ if was_training:
+ unwrapped.train()
+
+ # --- Pool to sequence-level embeddings ---
+ gen_emb = apply_embed_method(
+ gen_hidden,
+ args.ebft_embed_method,
+ gen_mask,
+ prompt_lengths=gen_prompt_lengths,
+ )
+ gt_emb = apply_embed_method(
+ gt_hidden,
+ args.ebft_embed_method,
+ gt_mask,
+ prompt_lengths=gt_prompt_lengths,
+ )
+
+ # --- Optional whitening ---
+ batch_size = gen_emb.shape[0]
+ if args.ebft_use_whitening and batch_size > 1:
+ num_prompts = batch_size // num_gens
+ gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
+ whitened_gen_list = []
+ whitened_gt_list = []
+ for i in range(num_prompts):
+ w_gen, w_gt = whiten_embeddings_batched(
+ gen_reshaped[i], gt_emb[i : i + 1]
+ )
+ whitened_gen_list.append(w_gen)
+ whitened_gt_list.append(w_gt)
+ gen_emb = torch.cat(whitened_gen_list, dim=0)
+ gt_emb = torch.cat(whitened_gt_list, dim=0)
+ else:
+ gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
+ gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)
+
+ # Repeat gt_emb: each GT repeated num_generations times
+ gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)
+
+ # --- Compute rewards ---
+ alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
+ diversity = get_diversity_rewards(gen_emb, num_gens)
+
+ # Scale by 2 per paper equation (7):
+ # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
+ alignment = alignment * 2
+ diversity = diversity * 2
+
+ rewards = (
+ args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
+ )
+
+ # Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
+ gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
+ mean_gen = gen_reshaped.mean(dim=1) # (num_prompts, D)
+ cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()
+
+ # Log feature-matching metrics to console and wandb
+ _align = alignment.mean().item()
+ _divers = diversity.mean().item()
+ _reward = rewards.mean().item()
+ _cfm = cfm_loss.item()
+
+ LOG.info(
+ f"ebft reward | "
+ f"align {_align:+.3f} ^ | "
+ f"divers {_divers:+.3f} v | "
+ f"cfm {_cfm:.3f} v | "
+ f"reward {_reward:+.3f} ^"
+ )
+
+ # Log to wandb via trainer's _metrics (picked up by GRPO's logging)
+ mode = "train" if self.model.training else "eval"
+ if hasattr(self, "_metrics"):
+ self._metrics[mode]["ebft/alignment"].append(_align)
+ self._metrics[mode]["ebft/diversity"].append(_divers)
+ self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
+ self._metrics[mode]["ebft/reward"].append(_reward)
+
+ return rewards.cpu().tolist()
+
+ @torch.no_grad()
+ def _sequential_rollout(
+ self,
+ prompts: list,
+ first_completions: list,
+ remaining_turns: list,
+ num_gens: int,
+ ) -> list:
+ """
+ Extend single-turn completions into multi-turn conversations.
+
+ For each prompt group, takes the first generated assistant turn and
+ sequentially generates subsequent assistant turns by calling vLLM,
+ building up a full multi-turn conversation.
+
+ Args:
+ prompts: List of prompt message lists (repeated num_gens times)
+ first_completions: List of generated first-turn completions
+ remaining_turns: List of remaining turn pairs after first assistant turn.
+ Each element is a list of dicts: [{"role": "user", "content": "..."},
+ {"role": "assistant", "content": "...GT..."}]
+ num_gens: Number of generations per prompt
+
+ Returns:
+ Extended completions incorporating all generated turns
+ """
+ vllm_client = self.vllm_generation.vllm_client
+ max_tokens = getattr(self.args, "max_completion_length", 256)
+ temperature = getattr(self.args, "temperature", 0.7)
+ gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}
+
+ extended_completions = []
+
+ for idx in range(len(prompts)):
+ prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
+ first_comp = first_completions[idx]
+
+ # Extract first completion text
+ if isinstance(first_comp, list):
+ first_text = first_comp[0].get("content", "") if first_comp else ""
+ else:
+ first_text = first_comp
+
+ # Get remaining turns for this prompt (same for all num_gens copies)
+ prompt_idx = idx // num_gens
+ turns = (
+ remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
+ )
+
+ if not turns:
+ extended_completions.append(first_text)
+ continue
+
+ # Build conversation with generated first turn
+ conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
+
+ # Generate subsequent turns
+ for turn in turns:
+ if turn["role"] == "user":
+ conv.append(turn)
+ elif turn["role"] == "assistant":
+ try:
+ result = vllm_client.chat(
+ messages=[conv],
+ n=1,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ generation_kwargs=gen_kwargs,
+ )
+ gen_ids = result.get("completion_ids", [[]])[0]
+ gen_text = self.processing_class.decode(
+ gen_ids, skip_special_tokens=True
+ )
+ except Exception as e:
+ LOG.warning(f"Multi-turn rollout generation failed: {e}")
+ gen_text = ""
+
+ conv.append({"role": "assistant", "content": gen_text})
+
+ # Render full conversation through chat template, then extract
+ # everything after the original prompt as the "completion" text.
+ # This preserves role markers and formatting between turns.
+ full_rendered = self.processing_class.apply_chat_template(
+ conv, tokenize=False, add_generation_prompt=False
+ )
+ prompt_rendered = self.processing_class.apply_chat_template(
+ prompt_msgs, tokenize=False, add_generation_prompt=True
+ )
+ completion_text = full_rendered[len(prompt_rendered) :]
+ extended_completions.append(completion_text)
+
+ return extended_completions
+
+
+class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
+ """EBFT trainer using synchronous GRPO (standard vLLM generation)."""
+
+ pass
+
+
+class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
+ """EBFT trainer using async GRPO (prefetches next batch during training)."""
+
+ pass
diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py
index acfd02909..2b8cda6d8 100644
--- a/src/axolotl/core/trainers/grpo/async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/async_trainer.py
@@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer):
"""
def __init__(self, *args, **kwargs):
- # When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
- # The communicator is not needed because weight sync happens via filesystem + HTTP,
- # and it fails when vLLM and a trainer rank share the same CUDA device.
+ # Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
+ # merged weight sync. NCCL is only needed for the standard update_named_param
+ # path which broadcasts tensors through the communicator.
training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
- if training_args is not None and getattr(
- training_args, "vllm_lora_sync", False
- ):
+ _skip_nccl = False
+ if training_args is not None:
+ if getattr(training_args, "vllm_lora_sync", False):
+ _skip_nccl = True # LoRA sync uses filesystem + HTTP
+ elif getattr(training_args, "async_prefetch", False):
+ # Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
+ # init_communicator allocates device tensors on rank 0 only, which
+ # causes DDP to see different param counts across ranks.
+ # The communicator is initialized lazily on first weight sync instead.
+ _skip_nccl = True
+ if _skip_nccl:
from trl.generation.vllm_generation import VLLMGeneration
_orig_init_vllm = VLLMGeneration._init_vllm
@@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
VLLMGeneration._init_vllm = _init_vllm_no_communicator
- super().__init__(*args, **kwargs)
+ try:
+ super().__init__(*args, **kwargs)
+ finally:
+ # Restore original _init_vllm so other trainers aren't affected
+ if _skip_nccl:
+ VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined]
# FP8 models: zero out the pad token embedding so that padding
# positions have zero hidden states throughout the network.
@@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._executor = None
def _submit_generation(self):
- """Submit the next background generation job."""
+ """Submit the next background generation job.
+
+ With multi-process (DDP/FSDP), only rank 0 generates to avoid
+ cross-rank NCCL collectives from background threads. Non-rank-0
+ processes enqueue a sentinel ``None`` that is replaced by a
+ broadcast in ``_prepare_inputs_legacy_async``.
+ """
+ rank0_only = self.accelerator.num_processes > 1
+ if rank0_only and not self.accelerator.is_main_process:
+ # Non-rank-0: nothing to generate; enqueue a resolved None future
+ f: concurrent.futures.Future = concurrent.futures.Future()
+ f.set_result(None)
+ self._async_queue.put(f)
+ return
batch = next(self._prompt_iter)
- future = self._executor.submit(self._generate_only, batch)
+ future = self._executor.submit(self._generate_only, batch, rank0_only)
self._async_queue.put(future)
+ # ------------------------------------------------------------------
+ # Broadcast rollout (legacy async, multi-process)
+ # ------------------------------------------------------------------
+
+ def _broadcast_rollout(self, rollout: dict | None) -> dict:
+ """Broadcast a rank0-only rollout dict to all ranks (main thread).
+
+ Rank 0 has the full rollout dict from ``_generate_only``; other ranks
+ have ``None``. After broadcast, tensors are moved to each rank's
+ local device.
+ """
+ import torch.distributed as dist
+
+ obj_list = [rollout if self.accelerator.is_main_process else None]
+ dist.broadcast_object_list(obj_list, src=0)
+ rollout = obj_list[0]
+ assert rollout is not None, "broadcast_object_list failed to deliver rollout"
+
+ # Move tensors to local device (broadcast deserializes to CPU)
+ device = self.accelerator.device
+ for key, val in rollout.items():
+ if isinstance(val, torch.Tensor) and val.device != device:
+ rollout[key] = val.to(device)
+
+ return rollout
+
# ------------------------------------------------------------------
# Weight sync
# ------------------------------------------------------------------
@@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer):
for Float8), and also safe for concurrent use since it never modifies base
weights in-place.
"""
- model = self.vllm_generation.model
accelerator = self.vllm_generation.accelerator
- vllm_client = self.vllm_generation.vllm_client
- fix_name = self.vllm_generation._fix_param_name_to_vllm
-
if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
return
+ # In multi-GPU async mode, we skip NCCL communicator init to avoid
+ # DDP param count mismatch and NCCL device conflicts. Weight sync
+ # uses the HTTP-only fallback in batch_update_named_params instead.
+
+ model = self.vllm_generation.model
+ vllm_client = self.vllm_generation.vllm_client
+ fix_name = self.vllm_generation._fix_param_name_to_vllm
+
# Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
lora_info = {}
for mod_name, module in model.base_model.model.named_modules():
@@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
weight_name = pname.replace(".weight_scale_inv", ".weight")
scale_inv_lookup[weight_name] = pparam.data
- # Iterate all parameters, computing merged weights for LoRA layers.
- # Skip LoRA-specific params and FP8 scale params (scales will be
- # recomputed by vLLM when it receives the merged bf16 weight).
+ # Only sync parameters that have LoRA modifications — skip unchanged
+ # base weights to avoid OOM on the vLLM GPU from allocating the entire
+ # model's worth of NCCL receive buffers.
params_to_sync = []
+ compute_dtype = torch.bfloat16
for name, param in model.named_parameters():
vllm_name = name.removeprefix("base_model.model.").replace(
".base_layer", ""
@@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer):
continue
if "original_module" in vllm_name:
continue
- # Skip FP8 quantization scale parameters - they are recomputed
- # on the vLLM side when we update the weight itself
if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
continue
+ if not vllm_name.endswith(".weight"):
+ continue
+ # fix_name strips modules_to_save.default. prefix
+ raw_mod_path = vllm_name[: -len(".weight")]
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
+ mod_path = vllm_name[: -len(".weight")]
+
+ # Sync weights that have LoRA adapters OR are modules_to_save
+ is_lora = mod_path in lora_info
+ is_modules_to_save = raw_mod_path != mod_path # fix_name stripped a prefix
+ if not is_lora and not is_modules_to_save:
+ continue
data = param.data
- compute_dtype = torch.bfloat16
- if vllm_name.endswith(".weight"):
- # Dequantize FP8 weights before merging
- if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
- scale_inv = scale_inv_lookup[name]
- # Block dequantization: weight * scale_inv (with broadcasting)
- fp8_bf16 = data.to(compute_dtype)
- if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
- # Block-quantized: scale_inv shape (rows/block, cols/block)
- sr, sc = scale_inv.shape
- br = fp8_bf16.shape[0] // sr # block height
- bc = fp8_bf16.shape[1] // sc # block width
- # Reshape → multiply by block scale → reshape back
- data = (
- fp8_bf16.reshape(sr, br, sc, bc)
- * scale_inv[:, None, :, None].to(compute_dtype)
- ).reshape(fp8_bf16.shape)
- elif scale_inv.dim() <= 1:
- # Per-tensor or per-channel scale
- data = fp8_bf16 * scale_inv.to(compute_dtype)
- else:
- data = fp8_bf16
- elif data.dtype == torch.float8_e4m3fn:
- # FP8 but no scale found - just cast (lossy)
- data = data.to(compute_dtype)
+ # Dequantize FP8 weights before merging
+ if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
+ scale_inv = scale_inv_lookup[name]
+ fp8_bf16 = data.to(compute_dtype)
+ if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
+ sr, sc = scale_inv.shape
+ br = fp8_bf16.shape[0] // sr
+ bc = fp8_bf16.shape[1] // sc
+ data = (
+ fp8_bf16.reshape(sr, br, sc, bc)
+ * scale_inv[:, None, :, None].to(compute_dtype)
+ ).reshape(fp8_bf16.shape)
+ elif scale_inv.dim() <= 1:
+ data = fp8_bf16 * scale_inv.to(compute_dtype)
+ else:
+ data = fp8_bf16
+ elif data.dtype == torch.float8_e4m3fn:
+ data = data.to(compute_dtype)
- mod_path = vllm_name[: -len(".weight")]
- if mod_path in lora_info:
- A, B, s = lora_info[mod_path]
- merged = data.to(compute_dtype) + s * (
- B.to(compute_dtype) @ A.to(compute_dtype)
- )
- data = merged
+ if is_lora:
+ A, B, s = lora_info[mod_path]
+ merged = data.to(compute_dtype) + s * (
+ B.to(compute_dtype) @ A.to(compute_dtype)
+ )
+ params_to_sync.append((vllm_name, merged))
+ else:
+ # modules_to_save: send raw weight (no LoRA merge needed)
+ params_to_sync.append((vllm_name, data.to(compute_dtype)))
- params_to_sync.append((vllm_name, data))
-
- # Batch sync all params in one HTTP+NCCL call (vs individual calls)
+ # Batch sync only LoRA-modified params via HTTP+NCCL
if params_to_sync:
+ sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
+ logger.info(
+ f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
+ )
vllm_client.batch_update_named_params(params_to_sync)
# Reset prefix cache after weight update
@@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
vllm_client = self.vllm_generation.vllm_client
url = f"{vllm_client.base_url}/set_lora_adapter/"
+ sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
response = requests.post(
url,
json={
@@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
"lora_int_id": self._lora_sync_version,
"lora_path": adapter_path,
},
- timeout=30,
+ timeout=sync_timeout,
)
if response.status_code != 200:
logger.warning(
@@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
step = self.state.global_step
interval = self.args.vllm_sync_interval
if step != self._last_synced_step and step % interval == 0:
+ if step == 0:
+ logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
+ self._last_synced_step = step
+ return
if getattr(self.args, "vllm_lora_sync", False):
- if step == 0:
- logger.info("Skipping LoRA sync at step 0 (no training yet)")
- self._last_synced_step = step
- return
# Native LoRA sync: save adapter to filesystem, vLLM loads it directly
self._sync_lora_adapter()
else:
@@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Background-thread generation (no scoring)
# ------------------------------------------------------------------
- def _generate_single_turn(self, prompts, **kwargs):
+ def _generate_single_turn(self, prompts, *args, **kwargs):
"""Override to prevent weight sync from background thread and to use
no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
is_bg = threading.current_thread() is not threading.main_thread()
@@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._patched_sync_weights = True
try:
- return super()._generate_single_turn(prompts, **kwargs)
+ return super()._generate_single_turn(prompts, *args, **kwargs)
finally:
if saved_step is not None:
self._last_loaded_step = saved_step
@@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
output = vg.vllm_client.chat(
messages=unique_prompts,
**sampling_params,
- chat_template_kwargs=vg.chat_template_kwargs,
- tools=vg.tools,
- chat_template=vg.chat_template,
+ chat_template_kwargs=self.chat_template_kwargs,
+ tools=self.tools,
+ chat_template=getattr(self, "chat_template", None),
)
else:
output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
@@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
logps_diff = per_token_logps_diff
is_ratio = torch.exp(logps_diff)
+ is_floor = 1.0 / is_cap # symmetric floor (e.g., cap=3.0 -> floor=0.333)
if is_mode in ("sequence_truncate", "token_truncate"):
- is_ratio = torch.clamp(is_ratio, max=is_cap)
+ is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
+ is_ratio = is_ratio.clamp(min=is_floor)
data["importance_sampling_ratio"] = is_ratio
# --- Collect rewards (launched before logprobs, should be done) ---
@@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
seq_is = is_mode in ("sequence_mask", "sequence_truncate")
logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
is_ratio = torch.exp(logps_diff)
+ # Symmetric floor clamp (matches non-streaming path at line ~1651)
+ is_floor = 1.0 / is_cap
if is_mode in ("sequence_truncate", "token_truncate"):
- is_ratio = torch.clamp(is_ratio, max=is_cap)
+ is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
+ is_ratio = is_ratio.clamp(min=is_floor)
if "importance_sampling_ratio" not in data:
total = len(data["prompt_ids"])
shape = (total, 1) if seq_is else (total, is_ratio.size(1))
@@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
rollout = future.result()
self._submit_generation()
+ # With multi-process, only rank 0 generated. Broadcast to all ranks.
+ if self.accelerator.num_processes > 1:
+ rollout = self._broadcast_rollout(rollout)
+
if self.args.streaming_partial_batch:
micro_batches = self._score_streaming(rollout)
else:
diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py
index 5f5ff3400..61464ee7d 100644
--- a/src/axolotl/integrations/diffusion/callbacks.py
+++ b/src/axolotl/integrations/diffusion/callbacks.py
@@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback):
logger.info("=" * 60)
if self.trainer.axolotl_cfg.use_wandb:
- if wandb.run is not None:
- wandb.log(
+ if wandb.run is not None: # type: ignore[attr-defined]
+ wandb.log( # type: ignore[attr-defined]
{
- "generated_samples": wandb.Table(
+ "generated_samples": wandb.Table( # type: ignore[attr-defined]
columns=[
"step",
"original",
diff --git a/src/axolotl/monkeypatch/trainer/trl_vllm.py b/src/axolotl/monkeypatch/trainer/trl_vllm.py
index a3296df61..e3f57ccf5 100644
--- a/src/axolotl/monkeypatch/trainer/trl_vllm.py
+++ b/src/axolotl/monkeypatch/trainer/trl_vllm.py
@@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__)
def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
- """Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
- from transformers import is_torch_xpu_available
+ """Batched weight sync — uses NCCL if communicator available, HTTP otherwise."""
+ has_communicator = getattr(self, "communicator", None) is not None
- if chunk_size is None:
- chunks = [params]
- else:
- chunks = []
- current_chunk: list[tuple[str, torch.Tensor]] = []
- current_elements = 0
- for name, weights in params:
- n_elem = weights.numel()
- if current_chunk and current_elements + n_elem > chunk_size:
- chunks.append(current_chunk)
- current_chunk = []
- current_elements = 0
- current_chunk.append((name, weights))
- current_elements += n_elem
- if current_chunk:
- chunks.append(current_chunk)
+ if has_communicator:
+ # Fast path: metadata via HTTP, tensors via NCCL
+ from transformers import is_torch_xpu_available
- for chunk in chunks:
- param_metadata = [
- {"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
- for name, weights in chunk
- ]
- url = f"{self.base_url}/batch_update_named_params/"
- response = self.session.post(url, json={"params": param_metadata})
- if response.status_code != 200:
- raise Exception(f"Request failed: {response.status_code}, {response.text}")
-
- for _name, weights in chunk:
- if is_torch_xpu_available():
- self.communicator.broadcast(weights, root=self.rank)
- else:
- self.communicator.broadcast(weights, src=self.rank)
-
- if is_torch_xpu_available():
- self.communicator.barrier()
+ if chunk_size is None:
+ chunks = [params]
else:
- self.communicator.group.barrier()
+ chunks = []
+ current_chunk: list[tuple[str, torch.Tensor]] = []
+ current_elements = 0
+ for name, weights in params:
+ n_elem = weights.numel()
+ if current_chunk and current_elements + n_elem > chunk_size:
+ chunks.append(current_chunk)
+ current_chunk = []
+ current_elements = 0
+ current_chunk.append((name, weights))
+ current_elements += n_elem
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ for chunk in chunks:
+ param_metadata = [
+ {
+ "name": name,
+ "dtype": str(weights.dtype),
+ "shape": list(weights.shape),
+ }
+ for name, weights in chunk
+ ]
+ url = f"{self.base_url}/batch_update_named_params/"
+ response = self.session.post(
+ url, json={"params": param_metadata}, timeout=120
+ )
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed: {response.status_code}, {response.text}"
+ )
+
+ for _name, weights in chunk:
+ if is_torch_xpu_available():
+ self.communicator.broadcast(weights, root=self.rank)
+ else:
+ self.communicator.broadcast(weights, src=self.rank)
+
+ if is_torch_xpu_available():
+ self.communicator.barrier()
+ else:
+ self.communicator.group.barrier()
+ else:
+ # HTTP-only path: encode tensor data in request body (no NCCL needed).
+ # Batch by byte size to avoid huge HTTP payloads.
+ MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024 # 10 MB
+ HTTP_TIMEOUT = 120 # seconds per request
+
+ payload: list[dict] = []
+ payload_bytes = 0
+ url = f"{self.base_url}/http_update_weights/"
+
+ def _flush(p: list[dict]) -> None:
+ if not p:
+ return
+ response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT)
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed: {response.status_code}, {response.text}"
+ )
+
+ from axolotl.utils.weight_serde import encode_for_http
+
+ for name, weights in params:
+ entry = encode_for_http(name, weights)
+ entry_bytes = weights.nelement() * weights.element_size()
+
+ # Flush current batch if adding this entry would exceed limit
+ if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST:
+ _flush(payload)
+ payload = []
+ payload_bytes = 0
+
+ payload.append(entry)
+ payload_bytes += entry_bytes
+
+ _flush(payload) # send remaining
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
diff --git a/src/axolotl/prompt_strategies/ebft/__init__.py b/src/axolotl/prompt_strategies/ebft/__init__.py
new file mode 100644
index 000000000..db46af005
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/__init__.py
@@ -0,0 +1,9 @@
+"""
+module for EBFT style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.ebft")
diff --git a/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py b/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
new file mode 100644
index 000000000..26982c183
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
@@ -0,0 +1,129 @@
+"""
+Dataset transform for multi-turn chat data with structured EBFT (vLLM mode).
+
+Three variants:
+
+1. `transform` — Uses the FIRST assistant turn as the generation target.
+ Passes remaining turns as `remaining_turns` for sequential rollout.
+ The trainer generates turn 1 via GRPO/vLLM, then sequentially generates
+ subsequent assistant turns, comparing the full conversation to GT.
+
+2. `transform_last_turn` — Uses the LAST assistant turn as the target.
+ Simplest approach: the full conversation history is the prompt.
+
+3. `transform_all_turns` — Explodes each conversation into N examples
+ (one per assistant turn). Each turn is an independent training example.
+ Use with batched=True.
+
+Supports OpenAI chat format:
+ {"messages": [{"role": ..., "content": ...}, ...]}
+"""
+
+
+def transform(cfg, **kwargs):
+ """Multi-turn with sequential rollout.
+
+ Returns the first assistant turn as ground_truth, plus remaining_turns
+ for the trainer to do sequential rollout generation.
+ """
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ if not messages:
+ return {"prompt": [], "ground_truth": ""}
+
+ # Split at first assistant turn
+ prompt_msgs = []
+ first_gt = None
+ remaining = []
+
+ found_first = False
+ for msg in messages:
+ if msg["role"] == "assistant" and not found_first:
+ first_gt = msg["content"]
+ found_first = True
+ elif found_first:
+ remaining.append(msg)
+ else:
+ prompt_msgs.append(msg)
+
+ if first_gt is None:
+ return {"prompt": prompt_msgs, "ground_truth": ""}
+
+ # Store only the first assistant turn as ground_truth. The full multi-turn
+ # GT is reconstructed in the reward function via chat template rendering
+ # (using remaining_turns), which preserves role markers between turns.
+ return {
+ "prompt": prompt_msgs,
+ "ground_truth": first_gt,
+ "remaining_turns": remaining,
+ }
+
+ return transform_fn, {
+ "remove_columns": "__all__",
+ }
+
+
+def transform_last_turn(cfg, **kwargs):
+ """Single-turn: use the last assistant turn as the generation target."""
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ if not messages:
+ return {"prompt": [], "ground_truth": ""}
+
+ # Find all assistant turns
+ history = []
+ last_prompt = []
+ last_gt = ""
+ for msg in messages:
+ if msg["role"] == "assistant":
+ last_prompt = list(history)
+ last_gt = msg["content"]
+ history.append(msg)
+
+ return {
+ "prompt": last_prompt,
+ "ground_truth": last_gt,
+ }
+
+ return transform_fn, {
+ "remove_columns": "__all__",
+ }
+
+
+def transform_all_turns(cfg, **kwargs):
+ """Explode: one example per assistant turn.
+
+ Use with datasets.map(batched=True) to produce N examples from
+ each N-turn conversation.
+
+ Usage in YAML:
+ type: ebft_chat_multiturn.transform_all_turns
+ """
+
+ def transform_fn(examples, tokenizer=None):
+ all_prompts = []
+ all_ground_truths = []
+
+ messages_list = examples.get("messages", examples.get("conversations", []))
+
+ for messages in messages_list:
+ history = []
+ for msg in messages:
+ if msg["role"] == "assistant":
+ all_prompts.append(list(history))
+ all_ground_truths.append(msg["content"])
+ history.append(msg)
+
+ return {
+ "prompt": all_prompts,
+ "ground_truth": all_ground_truths,
+ }
+
+ return transform_fn, {
+ "remove_columns": "__all__",
+ "batched": True,
+ }
diff --git a/src/axolotl/prompt_strategies/ebft/ebft_opencode.py b/src/axolotl/prompt_strategies/ebft/ebft_opencode.py
new file mode 100644
index 000000000..930d314a1
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/ebft_opencode.py
@@ -0,0 +1,20 @@
+"""
+Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode.
+
+Maps the dataset's `input` (prompt) and `output` (code solution) fields
+to the format expected by the EBFT trainer (prompt + ground_truth).
+"""
+
+
+def transform(cfg, **kwargs):
+ def transform_fn(example, tokenizer=None):
+ return {
+ "prompt": [
+ {"role": "user", "content": example["input"]},
+ ],
+ "ground_truth": example["output"],
+ }
+
+ return transform_fn, {
+ "remove_columns": "__all__",
+ }
diff --git a/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py b/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
new file mode 100644
index 000000000..7c756d158
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
@@ -0,0 +1,319 @@
+"""
+Dataset transform for reasoning/thinking datasets with EBFT.
+
+Handles datasets where assistant responses contain ... reasoning
+traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs).
+
+Two variants:
+
+1. `transform` — For structured EBFT (vLLM mode):
+ Returns prompt + ground_truth with thinking tags preserved.
+ Feature matching compares full responses (thinking + answer).
+
+2. `transform_answer_only` — For structured EBFT (vLLM mode):
+ Strips ... from ground_truth, so feature matching
+ only scores the final answer portion. Use when reasoning chains
+ can vary but the answer should match.
+
+3. `transform_strided` — For strided EBFT:
+ Tokenizes the full conversation with thinking traces.
+ Optionally masks thinking tokens from CE loss (labels=-100 for think spans)
+ while still placing anchors in thinking regions for feature matching.
+
+All variants work with OpenAI chat format:
+ {"messages": [{"role": "...", "content": "...Answer"}]}
+"""
+
+import re
+
+
+def _strip_thinking(text: str) -> str:
+ """Remove ... blocks from text."""
+ return re.sub(r".*?\s*", "", text, flags=re.DOTALL).strip()
+
+
+def _extract_thinking(text: str) -> tuple[str, str]:
+ """Split text into (thinking, answer) parts."""
+ match = re.search(r"(.*?)\s*(.*)", text, flags=re.DOTALL)
+ if match:
+ return match.group(1).strip(), match.group(2).strip()
+ return "", text.strip()
+
+
+def transform(cfg, **kwargs):
+ """Full response including thinking traces for feature matching.
+
+ For datasets where assistant content has ... tags in the
+ content field. The ground_truth includes the full content (thinking + answer).
+ """
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ prompt_msgs_snapshot = None
+ ground_truth = ""
+ for msg_idx, msg in enumerate(messages):
+ if msg["role"] == "assistant":
+ prompt_msgs_snapshot = list(messages[:msg_idx])
+ ground_truth = msg["content"]
+
+ return {
+ "prompt": prompt_msgs_snapshot
+ if prompt_msgs_snapshot is not None
+ else messages[:-1],
+ "ground_truth": ground_truth,
+ }
+
+ return transform_fn, {"remove_columns": "__all__"}
+
+
+def transform_split_thinking(cfg, **kwargs):
+ """Split tags into reasoning_content field for native chat template handling.
+
+ For datasets where thinking is embedded in the content field as ....
+ Splits it into separate reasoning_content and content fields so the model's
+ chat template can format it natively (e.g., Qwen3.5's reasoning_content support).
+
+ The prompt messages are passed through with reasoning_content properly split,
+ so vLLM generation with enable_thinking=true produces comparable outputs.
+ The ground_truth is the full assistant response (thinking + answer) for
+ feature matching.
+
+ Also works for:
+ - ... tags
+ - <|begin_of_thought|>...<|end_of_thought|> tags
+ """
+ _THINKING_PAIRS = [
+ ("", ""),
+ ("", ""),
+ ("<|begin_of_thought|>", "<|end_of_thought|>"),
+ ]
+
+ def _split_msg_thinking(msg):
+ """Split thinking from assistant message content into reasoning_content.
+
+ Always includes reasoning_content key on assistant messages (empty string
+ if no thinking tags found) to ensure consistent HF dataset schema across
+ all examples in a batch.
+ """
+ if msg["role"] != "assistant":
+ return msg
+ content = msg.get("content", "")
+ # Already has reasoning_content — pass through
+ if "reasoning_content" in msg:
+ return msg
+ for open_tag, close_tag in _THINKING_PAIRS:
+ if open_tag in content and close_tag in content:
+ start = content.find(open_tag)
+ end = content.find(close_tag)
+ thinking = content[start + len(open_tag) : end].strip()
+ answer = content[end + len(close_tag) :].strip()
+ return {
+ **msg,
+ "reasoning_content": thinking,
+ "content": answer,
+ }
+ # No thinking tags — still add reasoning_content for schema consistency
+ return {**msg, "reasoning_content": ""}
+
+ def _normalize_msg(msg):
+ """Ensure every message has {role, content, reasoning_content} for HF schema consistency."""
+ return {
+ "role": msg.get("role", ""),
+ "content": msg.get("content", ""),
+ "reasoning_content": msg.get("reasoning_content", ""),
+ }
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ # Split thinking in all assistant messages, then normalize schema
+ split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages]
+
+ # Build prompt (all messages except last assistant) and ground_truth
+ prompt_msgs = []
+ prompt_msgs_snapshot = None
+ ground_truth = ""
+ for msg in split_messages:
+ if msg["role"] == "assistant":
+ prompt_msgs_snapshot = list(prompt_msgs)
+ # ground_truth is the FULL content for feature matching
+ thinking = msg.get("reasoning_content", "")
+ answer = msg.get("content", "")
+ if thinking:
+ ground_truth = f"\n{thinking}\n\n\n{answer}"
+ else:
+ ground_truth = answer
+ prompt_msgs.append(msg)
+
+ return {
+ "prompt": prompt_msgs_snapshot
+ if prompt_msgs_snapshot is not None
+ else split_messages[:-1],
+ "ground_truth": ground_truth,
+ }
+
+ return transform_fn, {"remove_columns": "__all__"}
+
+
+def transform_answer_only(cfg, **kwargs):
+ """Strip thinking from ground_truth — match features on answer only."""
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ prompt_msgs = []
+ prompt_msgs_snapshot = None
+ ground_truth = ""
+ for msg in messages:
+ if msg["role"] == "assistant":
+ prompt_msgs_snapshot = list(prompt_msgs)
+ ground_truth = _strip_thinking(msg["content"])
+ prompt_msgs.append(msg)
+
+ return {
+ "prompt": prompt_msgs_snapshot
+ if prompt_msgs_snapshot is not None
+ else messages[:-1],
+ "ground_truth": ground_truth,
+ }
+
+ return transform_fn, {"remove_columns": "__all__"}
+
+
+def transform_strided(cfg, **kwargs):
+ """For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss.
+
+ Config options (via cfg):
+ - ebft.mask_thinking_ce: bool (default False)
+ If True, set labels=-100 for tokens inside ... blocks.
+ Feature matching still uses these positions (anchors are placed everywhere
+ in the completion span). Only CE auxiliary loss is affected.
+ """
+ seq_len = cfg.sequence_len
+ mask_thinking = False
+ if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"):
+ mask_thinking = cfg.ebft.mask_thinking_ce
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ if tokenizer is None:
+ for m in messages:
+ if m.get("role") == "user":
+ return {"prompt": m["content"]}
+ return {"prompt": str(messages)}
+
+ pad_id = (
+ tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None
+ else tokenizer.eos_token_id
+ )
+
+ # Tokenize the full conversation with the chat template
+ full_text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ full_enc = tokenizer(
+ full_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )
+ input_ids = full_enc["input_ids"]
+
+ # Build labels: -100 for non-assistant tokens
+ labels = [-100] * len(input_ids)
+
+ # Find assistant turn boundaries using incremental tokenization.
+ # Only the FINAL assistant turn is marked as trainable.
+ prefix_messages = []
+ final_start = None
+ final_end = None
+ for msg in messages:
+ if msg["role"] == "assistant":
+ prefix_text = tokenizer.apply_chat_template(
+ prefix_messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ prefix_ids = tokenizer(
+ prefix_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )["input_ids"]
+ start = len(prefix_ids)
+
+ prefix_messages.append(msg)
+ with_turn_text = tokenizer.apply_chat_template(
+ prefix_messages,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ with_turn_ids = tokenizer(
+ with_turn_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )["input_ids"]
+ end = len(with_turn_ids)
+
+ # Record this turn's boundaries; only the last one will be used
+ final_start = start
+ final_end = end
+ else:
+ prefix_messages.append(msg)
+
+ # Mark only the final assistant turn as trainable
+ if final_start is not None and final_end is not None:
+ for i in range(final_start, min(final_end, len(labels))):
+ labels[i] = input_ids[i]
+
+ # Optionally mask ... tokens within this turn.
+ # Find think spans by scanning for and token IDs
+ # directly in the input_ids (robust to tokenization alignment).
+ if mask_thinking:
+ think_open_id = tokenizer.convert_tokens_to_ids("")
+ think_close_id = tokenizer.convert_tokens_to_ids("")
+ if think_open_id != tokenizer.unk_token_id:
+ # Scan from before the assistant turn start to catch
+ # tags that are part of the template prefix
+ scan_start = max(0, final_start - 5)
+ in_think = False
+ for i in range(scan_start, min(final_end, len(labels))):
+ if input_ids[i] == think_open_id:
+ in_think = True
+ if in_think and i >= final_start:
+ labels[i] = -100
+ if input_ids[i] == think_close_id:
+ in_think = False
+ if i >= final_start:
+ labels[i] = -100
+
+ # Derive prompt_length
+ prompt_length = len(input_ids)
+ for i, lbl in enumerate(labels):
+ if lbl != -100:
+ prompt_length = i
+ break
+
+ # Pad
+ pad_len = seq_len - len(input_ids)
+ attention_mask = [1] * len(input_ids) + [0] * pad_len
+ labels = labels + [-100] * pad_len
+ input_ids = input_ids + [pad_id] * pad_len
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ "prompt_length": prompt_length,
+ }
+
+ return transform_fn, {"remove_columns": "__all__"}
diff --git a/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py b/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
new file mode 100644
index 000000000..f1d0c01f9
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
@@ -0,0 +1,110 @@
+"""
+Dataset transform for multi-turn chat data with strided EBFT.
+
+Tokenizes conversations using the model's chat template, producing input_ids
+with labels=-100 for system/user turns and real labels for assistant turns.
+The strided trainer places anchors only within assistant completion spans.
+
+Works with datasets in OpenAI chat format:
+ [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
+"""
+
+
+def transform(cfg, **kwargs):
+ seq_len = cfg.sequence_len
+
+ def transform_fn(example, tokenizer=None):
+ messages = example.get("messages", example.get("conversations", []))
+
+ if tokenizer is None:
+ # For preview: just return the first user message
+ for m in messages:
+ if m.get("role") == "user":
+ return {"prompt": m["content"]}
+ return {"prompt": str(messages)}
+
+ pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+
+ # Tokenize the full conversation with the chat template
+ full_text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ full_enc = tokenizer(
+ full_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )
+ input_ids = full_enc["input_ids"]
+
+ # Build labels: -100 for everything except assistant turns.
+ # Strategy: tokenize incrementally to find assistant turn boundaries.
+ labels = [-100] * len(input_ids)
+
+ # Tokenize prefix up to each assistant turn to find boundaries
+ prefix_messages = []
+ for msg in messages:
+ if msg["role"] == "assistant":
+ # Tokenize prefix (everything before this assistant turn + generation prompt)
+ prefix_text = tokenizer.apply_chat_template(
+ prefix_messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ prefix_ids = tokenizer(
+ prefix_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )["input_ids"]
+ start = len(prefix_ids)
+
+ # Tokenize prefix + this assistant turn
+ prefix_messages.append(msg)
+ with_turn_text = tokenizer.apply_chat_template(
+ prefix_messages,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ with_turn_ids = tokenizer(
+ with_turn_text,
+ truncation=True,
+ max_length=seq_len,
+ add_special_tokens=False,
+ return_tensors=None,
+ )["input_ids"]
+ end = len(with_turn_ids)
+
+ # Mark assistant tokens as trainable
+ for i in range(start, min(end, len(labels))):
+ labels[i] = input_ids[i]
+ else:
+ prefix_messages.append(msg)
+
+ # Derive prompt_length as the position of the first non-masked label
+ prompt_length = len(input_ids) # default: all masked
+ for i, lbl in enumerate(labels):
+ if lbl != -100:
+ prompt_length = i
+ break
+
+ # Pad to seq_len
+ pad_len = seq_len - len(input_ids)
+ attention_mask = [1] * len(input_ids) + [0] * pad_len
+ labels = labels + [-100] * pad_len
+ input_ids = input_ids + [pad_id] * pad_len
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ "prompt_length": prompt_length,
+ }
+
+ return transform_fn, {
+ "remove_columns": "__all__",
+ }
diff --git a/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py b/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
new file mode 100644
index 000000000..767650575
--- /dev/null
+++ b/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
@@ -0,0 +1,80 @@
+"""
+Dataset transform for structured (prompt, completion) data with strided EBFT.
+
+Tokenizes prompt and completion separately, concatenates into a single
+input_ids sequence, and marks prompt tokens with labels=-100 so the
+strided trainer knows where to place anchors (completion span only).
+
+Reads `input`/`output`, `prompt`/`completion`, or `question`/`answer` fields (e.g., nvidia/OpenCodeInstruct).
+"""
+
+
+def transform(cfg, **kwargs):
+ seq_len = cfg.sequence_len
+
+ def transform_fn(example, tokenizer=None):
+ # Extract prompt and completion from the example
+ prompt_text = example.get(
+ "input", example.get("prompt", example.get("question", ""))
+ )
+ completion_text = example.get(
+ "output", example.get("completion", example.get("answer", ""))
+ )
+
+ if tokenizer is None:
+ return {"prompt": prompt_text}
+
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+
+ # Tokenize prompt and completion separately
+ prompt_enc = tokenizer(
+ prompt_text,
+ truncation=False,
+ add_special_tokens=True,
+ return_tensors=None,
+ )
+ completion_enc = tokenizer(
+ completion_text,
+ truncation=False,
+ add_special_tokens=False,
+ return_tensors=None,
+ )
+
+ prompt_ids = prompt_enc["input_ids"]
+ completion_ids = completion_enc["input_ids"]
+
+ # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
+ total_len = len(prompt_ids) + len(completion_ids)
+ if total_len > seq_len:
+ # Truncate completion first, then prompt if needed
+ max_completion = seq_len - len(prompt_ids)
+ if max_completion < 1:
+ # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
+ prompt_ids = prompt_ids[: seq_len - 1]
+ completion_ids = completion_ids[:1]
+ else:
+ completion_ids = completion_ids[:max_completion]
+
+ input_ids = prompt_ids + completion_ids
+ prompt_length = len(prompt_ids)
+
+ # Labels: -100 for prompt tokens, input_ids for completion tokens
+ labels = [-100] * prompt_length + completion_ids
+
+ # Pad to seq_len
+ pad_len = seq_len - len(input_ids)
+ attention_mask = [1] * len(input_ids) + [0] * pad_len
+ labels = labels + [-100] * pad_len
+ input_ids = input_ids + [pad_id] * pad_len
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ "prompt_length": prompt_length,
+ }
+
+ # Signal to remove all original columns (filtered to existing ones at map time)
+ return transform_fn, {
+ "remove_columns": "__all__",
+ }
diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py
index 9ce4d2771..f4fcfa190 100644
--- a/src/axolotl/scripts/vllm_serve_lora.py
+++ b/src/axolotl/scripts/vllm_serve_lora.py
@@ -241,6 +241,23 @@ def main(script_args: ScriptArguments):
app = FastAPI(lifespan=lifespan)
+ # --- Access logging middleware ---
+ import time as _time
+
+ @app.middleware("http")
+ async def access_log_middleware(request, call_next):
+ t0 = _time.monotonic()
+ response = await call_next(request)
+ elapsed = _time.monotonic() - t0
+ logger.info(
+ "%s %s %d %.3fs",
+ request.method,
+ request.url.path,
+ response.status_code,
+ elapsed,
+ )
+ return response
+
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
@@ -300,7 +317,11 @@ def main(script_args: ScriptArguments):
import vllm
from packaging.version import Version
- from vllm.sampling_params import GuidedDecodingParams
+
+ try:
+ from vllm.sampling_params import GuidedDecodingParams
+ except ImportError:
+ GuidedDecodingParams = None # not available in vLLM 0.17+
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
prompts: list[dict[str, Any]] = []
@@ -362,7 +383,12 @@ def main(script_args: ScriptArguments):
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
- all_outputs = [conn.recv() for conn in connections]
+ # Use run_in_executor so blocking recv() doesn't freeze the event loop
+ # (allows /set_lora_adapter/ and other endpoints to be served concurrently)
+ loop = asyncio.get_running_loop()
+ all_outputs = await asyncio.gather(
+ *(loop.run_in_executor(None, conn.recv) for conn in connections)
+ )
all_outputs = [
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
]
@@ -404,7 +430,10 @@ def main(script_args: ScriptArguments):
}
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
- all_outputs = [conn.recv() for conn in connections]
+ loop = asyncio.get_running_loop()
+ all_outputs = await asyncio.gather(
+ *(loop.run_in_executor(None, conn.recv) for conn in connections)
+ )
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs))
@@ -474,11 +503,51 @@ def main(script_args: ScriptArguments):
)
return {"message": f"Batch update for {len(params_list)} params"}
+ class HTTPWeightUpdateRequest(BaseModel):
+ """Weight update via HTTP (no NCCL needed)."""
+
+ params: list[
+ dict
+ ] # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]
+
+ @app.post("/http_update_weights/")
+ async def http_update_weights(request: HTTPWeightUpdateRequest):
+ """Update model weights via HTTP — no NCCL communicator required.
+
+ Tensor data is sent as base64-encoded raw bytes in the request body.
+ Slower than NCCL for large models but works without cross-process setup.
+ """
+ from axolotl.utils.weight_serde import (
+ decode_from_http,
+ encode_for_ipc,
+ )
+
+ weights_to_load = [decode_from_http(p) for p in request.params]
+
+ # Send all weights in a single IPC call. Tensors don't survive
+ # vLLM's multiproc IPC, so serialize as raw bytes + metadata.
+ param_entries = [
+ encode_for_ipc(name, weight) for name, weight in weights_to_load
+ ]
+ kwargs = {
+ "method": "http_load_weights_batch",
+ "kwargs": {"params": param_entries},
+ }
+ msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
+ loop = asyncio.get_running_loop()
+ await asyncio.gather(
+ *(loop.run_in_executor(None, c.send, msg) for c in connections)
+ )
+ return {"message": f"HTTP weight update for {len(weights_to_load)} params"}
+
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"})
- results = [conn.recv() for conn in connections]
+ loop = asyncio.get_running_loop()
+ results = await asyncio.gather(
+ *(loop.run_in_executor(None, conn.recv) for conn in connections)
+ )
return {"message": f"Reset prefix cache: {all(results)}"}
@app.post("/close_communicator/")
diff --git a/src/axolotl/scripts/vllm_worker_ext.py b/src/axolotl/scripts/vllm_worker_ext.py
index 386460df1..11f8e6ceb 100644
--- a/src/axolotl/scripts/vllm_worker_ext.py
+++ b/src/axolotl/scripts/vllm_worker_ext.py
@@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
model = self.model_runner.model
params_dict = dict(model.named_parameters())
+ # Handle VLM models where trainer and vLLM use different prefixes.
+ # Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..."
+ # vLLM (Qwen3.5): "language_model.model.layers.X..."
+ if name not in params_dict:
+ # Try common prefix remappings
+ for src_prefix, dst_prefix in [
+ ("model.language_model.layers.", "language_model.model.layers."),
+ ("model.layers.", "language_model.model.layers."),
+ ]:
+ if name.startswith(src_prefix):
+ name = dst_prefix + name[len(src_prefix) :]
+ break
+
# Check if this is a simple direct param (exists as-is)
if name in params_dict:
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
@@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
return
# Fallback: try load_weights (may work for non-stacked params)
- logger.warning("Falling back to load_weights for param: %s", name)
+ # Log a few vLLM param names for debugging (filter hardcodes layer 31, so shallower models log an empty list)
+ sample_keys = [
+ k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k
+ ][:3]
+ logger.warning(
+ "Falling back to load_weights for param: %s (sample vLLM keys: %s)",
+ name,
+ sample_keys,
+ )
model.load_weights(weights=[(name, weight)])
def update_named_param(self, name, dtype, shape):
@@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
# Load weights using direct set (handles stacked params)
for name, weight in weights_to_load:
self._direct_set_weight(name, weight)
+
+ def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]):
+ """Load weights received via HTTP (no NCCL needed)."""
+ for name, weight in weights:
+ self._direct_set_weight(name, weight.to(self.device))
+
+ def http_load_weight(self, **kwargs):
+ """Load a single weight received via HTTP (no NCCL needed).
+
+ Reconstructs the tensor from raw bytes since tensors don't survive
+ vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights``
+ which handles TP sharding and stacked-param packing automatically.
+ """
+ from axolotl.utils.weight_serde import decode_from_ipc
+
+ name, weight = decode_from_ipc(kwargs)
+ model = self.model_runner.model
+ model.load_weights(weights=[(name, weight)])
+
+ def http_load_weights_batch(self, params: list[dict]):
+ """Load multiple weights in a single IPC call.
+
+ Uses vLLM's ``load_weights`` which handles TP sharding automatically.
+ """
+ from axolotl.utils.weight_serde import decode_from_ipc
+
+ model = self.model_runner.model
+ weights = [decode_from_ipc(p) for p in params]
+ model.load_weights(weights=weights)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 522dd7e28..774aa1cec 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -138,7 +138,11 @@ def setup_reference_model(
model_ref = None # explicit setting to None
else:
reference_model: bool = True
- if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
+ trl_cfg = getattr(cfg, "trl", None)
+ if (
+ cfg.rl in {RLType.GRPO, RLType.EBFT}
+ and getattr(trl_cfg, "beta", 0) == 0
+ ):
reference_model = False
# load the model again for model_ref/baseline
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
@@ -206,7 +210,7 @@ def execute_training(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
- gather_outputs=cfg.rl is RLType.GRPO,
+ gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT},
device_mesh=trainer.accelerator.torch_device_mesh,
)
)
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 36370ef13..afdb7f2a2 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
].append(pred_step_text)
row_index += 1
if logger == "wandb":
- # type: ignore[attr-defined]
- wandb.run.log(
+ wandb.run.log( # type: ignore[attr-defined]
{
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
table_data
@@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
- artifact = wandb.Artifact(
- f"config-{wandb.run.id}", type="axolotl-config"
+ artifact = wandb.Artifact( # type: ignore[attr-defined]
+ f"config-{wandb.run.id}", # type: ignore[attr-defined]
+ type="axolotl-config",
)
artifact.add_file(temp_file.name)
- wandb.log_artifact(artifact)
- wandb.save(temp_file.name)
+ wandb.log_artifact(artifact) # type: ignore[attr-defined]
+ wandb.save(temp_file.name) # type: ignore[attr-defined]
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
@@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()
- artifact = wandb.Artifact(
- f"chat-template-{wandb.run.id}", type="jinja-template"
+ artifact = wandb.Artifact( # type: ignore[attr-defined]
+ f"chat-template-{wandb.run.id}", # type: ignore[attr-defined]
+ type="jinja-template",
)
artifact.add_file(temp_ct_file.name)
- wandb.log_artifact(artifact)
- wandb.save(temp_ct_file.name)
+ wandb.log_artifact(artifact) # type: ignore[attr-defined]
+ wandb.save(temp_ct_file.name) # type: ignore[attr-defined]
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
@@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
else:
skip_upload = True
if not skip_upload:
- artifact = wandb.Artifact(
- f"deepspeed-config-{wandb.run.id}",
+ artifact = wandb.Artifact( # type: ignore[attr-defined]
+ f"deepspeed-config-{wandb.run.id}", # type: ignore[attr-defined]
type="deepspeed-config",
)
artifact.add_file(temp_file.name)
- wandb.log_artifact(artifact)
- wandb.save(temp_file.name)
+ wandb.log_artifact(artifact) # type: ignore[attr-defined]
+ wandb.save(temp_file.name) # type: ignore[attr-defined]
LOG.info(
"The DeepSpeed config has been saved to the WandB run under files."
)
diff --git a/src/axolotl/utils/callbacks/generation.py b/src/axolotl/utils/callbacks/generation.py
index 439258c8b..da36b9ad0 100644
--- a/src/axolotl/utils/callbacks/generation.py
+++ b/src/axolotl/utils/callbacks/generation.py
@@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback):
if not getattr(cfg, "generate_samples", False):
return
- dataloader = None
- try:
- if getattr(self.trainer, "eval_dataset", None) is not None:
- dataloader = self.trainer.get_eval_dataloader()
- LOG.info(
- f"Using eval dataloader for generation at step {state.global_step}"
- )
- except Exception as e:
- LOG.warning(f"Could not get eval dataloader: {e}")
- dataloader = None
-
- if dataloader is None:
- dataloader = self.trainer.get_train_dataloader()
+ dataloader = None
+ try:
+ if getattr(self.trainer, "eval_dataset", None) is not None:
+ dataloader = self.trainer.get_eval_dataloader()
LOG.info(
- f"Using train dataloader for generation at step {state.global_step}"
+ f"Using eval dataloader for generation at step {state.global_step}"
)
+ except Exception as e:
+ LOG.warning(f"Could not get eval dataloader: {e}")
+ dataloader = None
- samples = generate_samples(
- model=self.trainer.model,
- tokenizer=self.trainer.processing_class,
- dataloader=dataloader,
- num_generation_samples=getattr(cfg, "num_generation_samples", 3),
- max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
- temperature=getattr(cfg, "generation_temperature", 0.7),
- top_p=getattr(cfg, "generation_top_p", None),
- top_k=getattr(cfg, "generation_top_k", None),
- do_sample=getattr(cfg, "generation_do_sample", True),
- prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
+ if dataloader is None:
+ dataloader = self.trainer.get_train_dataloader()
+ LOG.info(
+ f"Using train dataloader for generation at step {state.global_step}"
)
- self._log_samples(samples, state.global_step)
+
+ samples = generate_samples(
+ model=self.trainer.model,
+ tokenizer=self.trainer.processing_class,
+ dataloader=dataloader,
+ num_generation_samples=getattr(cfg, "num_generation_samples", 3),
+ max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
+ temperature=getattr(cfg, "generation_temperature", 0.7),
+ top_p=getattr(cfg, "generation_top_p", None),
+ top_k=getattr(cfg, "generation_top_k", None),
+ do_sample=getattr(cfg, "generation_do_sample", True),
+ prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
+ )
+ self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples to console and W&B."""
@@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback):
try:
import wandb
- if wandb.run is not None:
- wandb.log(
+ if wandb.run is not None: # type: ignore[attr-defined]
+ wandb.log( # type: ignore[attr-defined]
{
- f"samples/sample_{i + 1}": wandb.Html(
+ f"samples/sample_{i + 1}": wandb.Html( # type: ignore[attr-defined]
f"{wandb_text}"
)
},
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 2c386f35e..ef91e1124 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer
from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.lock import FileLockLoader
@@ -173,7 +174,7 @@ def _drop_long_sequences(
return (len_prompt + len_completion) <= sequence_len
- if rl in {RLType.GRPO, RLType.GDPO}:
+ if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}:
return True
raise ValueError("Unknown RL type")
@@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
+ elif cfg.rl is RLType.EBFT:
+ ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
map_kwargs: dict[str, Any] = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
+ # Handle remove_columns: "__all__" removes all original columns,
+ # or filter a list to only columns that exist in the dataset
+ if "remove_columns" in map_kwargs:
+ ds_columns = (
+ dataset.column_names
+ if isinstance(dataset, Dataset)
+ else dataset[split].column_names
+ if isinstance(dataset, DatasetDict)
+ else []
+ )
+ if map_kwargs["remove_columns"] == "__all__":
+ map_kwargs["remove_columns"] = list(ds_columns)
+ else:
+ map_kwargs["remove_columns"] = [
+ c for c in map_kwargs["remove_columns"] if c in ds_columns
+ ]
split_datasets[i] = _map_dataset(
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 34fd9ba2c..982e3e419 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__)
+class EBFTConfig(BaseModel):
+ """Configuration for Energy-Based Fine-Tuning (EBFT)"""
+
+ feature_layers: list[float] = Field(
+ default=[0.25, 0.5, 0.75],
+ json_schema_extra={
+ "description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])"
+ },
+ )
+ embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = (
+ Field(
+ default="last_token",
+ json_schema_extra={
+ "description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"
+ },
+ )
+ )
+ use_whitening: bool = Field(
+ default=False,
+ json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
+ )
+ alignment_coef: float = Field(
+ default=1.0,
+ json_schema_extra={
+ "description": "Coefficient for alignment reward (cosine similarity with ground truth)"
+ },
+ )
+ diversity_coef: float = Field(
+ default=1.0,
+ json_schema_extra={
+ "description": "Coefficient for diversity penalty (pairwise similarity between samples)"
+ },
+ )
+ ce_coef: float = Field(
+ default=0.0,
+ json_schema_extra={
+ "description": "Cross-entropy loss coefficient on ground-truth tokens"
+ },
+ )
+ adaptive_max_tokens: bool = Field(
+ default=True,
+ json_schema_extra={
+ "description": "Set per-batch max_tokens based on ground-truth length"
+ },
+ )
+ gt_length_multiplier: float = Field(
+ default=1.5,
+ ge=0.1,
+ json_schema_extra={
+ "description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
+ },
+ )
+
+ # Strided mode fields (for unstructured text)
+ mode: Literal["structured", "strided"] = Field(
+ default="structured",
+ json_schema_extra={
+ "description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
+ },
+ )
+ stride: int = Field(
+ default=8,
+ ge=1,
+ json_schema_extra={"description": "Stride between anchor points (tokens)"},
+ )
+ context_length: int = Field(
+ default=8,
+ ge=1,
+ json_schema_extra={"description": "Context window size per block"},
+ )
+ generate_max_len: int = Field(
+ default=8,
+ ge=1,
+ json_schema_extra={"description": "Tokens to generate per block"},
+ )
+ n_samples_per_prompt: int = Field(
+ default=4,
+ ge=1,
+ json_schema_extra={"description": "Independent rollouts per document"},
+ )
+ temperature: float = Field(
+ default=0.6,
+ ge=0.0,
+ json_schema_extra={
+ "description": "Sampling temperature for strided generation"
+ },
+ )
+ top_p: float = Field(
+ default=1.0,
+ ge=0.0,
+ le=1.0,
+ json_schema_extra={"description": "Top-p nucleus sampling threshold"},
+ )
+ rl_coef: float = Field(
+ default=1.0,
+ json_schema_extra={"description": "RL policy gradient loss coefficient"},
+ )
+ advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
+ default="rloo",
+ json_schema_extra={
+ "description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
+ },
+ )
+ min_completion_prefix: int = Field(
+ default=0,
+ ge=0,
+ json_schema_extra={
+ "description": "Minimum tokens into completion before placing anchors. "
+ "Skips anchors too close to the prompt boundary where features are dominated by prompt context."
+ },
+ )
+
+
class AxolotlInputConfig(
ModelInputConfig,
ModelOutputConfig,
@@ -131,7 +244,7 @@ class AxolotlInputConfig(
rl: RLType | None = Field(
default=None,
json_schema_extra={
- "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
+ "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'"
},
)
trl: TRLConfig | None = Field(
@@ -140,6 +253,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
+ ebft: EBFTConfig | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
+ },
+ )
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 40fa314f4..7ffa793f2 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -35,6 +35,7 @@ class RLType(str, Enum):
ORPO = "orpo"
KTO = "kto"
SIMPO = "simpo"
+ EBFT = "ebft"
class ChatTemplate(str, Enum):
diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py
index 2d7c36f96..4ef42db66 100644
--- a/src/axolotl/utils/schemas/trl.py
+++ b/src/axolotl/utils/schemas/trl.py
@@ -1,6 +1,6 @@
"""Pydantic models for TRL trainer configuration"""
-from typing import Literal
+from typing import Any, Literal
from pydantic import BaseModel, Field
@@ -133,6 +133,20 @@ class TRLConfig(BaseModel):
"description": "Penalty for tokens that appear in prompt and generated text."
},
)
+ generation_kwargs: dict[str, Any] | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Additional generation parameters passed to vLLM SamplingParams. "
+ "Useful for stop_token_ids, seed, frequency_penalty, etc."
+ },
+ )
+ chat_template_kwargs: dict[str, Any] | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Additional kwargs for the chat template. "
+ "E.g., {enable_thinking: false} for Qwen3.5 models."
+ },
+ )
num_iterations: int | None = Field(
default=None,
json_schema_extra={
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index c902d8703..c7eeb6fa4 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1482,6 +1482,124 @@ class DistributedValidationMixin:
return self
+class EBFTValidationMixin:
+ """Validation for EBFT (Energy-Based Fine-Tuning) configuration."""
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_config_required(cls, data):
+ """rl: ebft requires an ebft config section."""
+ if data.get("rl") == "ebft" and not data.get("ebft"):
+ raise ValueError(
+ "`ebft` config section is required when `rl: ebft` is set. "
+ "Add an `ebft:` section with at least `mode: structured` or `mode: strided`."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_torch_compile(cls, data):
+ """torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles
+ and CheckpointErrors. The flex_attention kernel compiles itself internally —
+ whole-model torch.compile is not needed and actively harmful."""
+ if (
+ data.get("rl") == "ebft"
+ and data.get("torch_compile") is True
+ and data.get("ebft", {}).get("mode") == "strided"
+ ):
+ if data.get("gradient_checkpointing"):
+ raise ValueError(
+ "EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` "
+ "causes CheckpointError (BlockMask metadata mismatch during recomputation). "
+ "Remove `torch_compile` — the flex_attention kernel compiles itself internally."
+ )
+ LOG.warning(
+ "EBFT strided mode: `torch_compile: true` causes dynamo recompiles from "
+ "variable sequence lengths across steps. Consider removing it — "
+ "flex_attention compiles itself internally."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_gradient_checkpointing_reentrant(cls, data):
+ """flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
+ if (
+ data.get("rl") == "ebft"
+ and data.get("ebft", {}).get("mode") == "strided"
+ and data.get("flex_attention")
+ and data.get("gradient_checkpointing")
+ ):
+ gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
+ if not gc_kwargs.get("use_reentrant"):
+ LOG.warning(
+ "EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
+ "gradient_checkpointing_kwargs (required for flex_attention compatibility). "
+ "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
+ )
+ if data.get("gradient_checkpointing_kwargs") is None:
+ data["gradient_checkpointing_kwargs"] = {}
+ data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_activation_offloading(cls, data):
+ """activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
+ which conflicts with flex_attention's use_reentrant requirement."""
+ if (
+ data.get("rl") == "ebft"
+ and data.get("ebft", {}).get("mode") == "strided"
+ and data.get("activation_offloading") is True
+ and data.get("flex_attention")
+ ):
+ raise ValueError(
+ "EBFT strided mode: `activation_offloading: true` is incompatible with "
+ "`flex_attention: true`. Activation offloading replaces gradient checkpointing "
+ "with FSDP-style wrapping that conflicts with flex_attention's reentrant "
+ "checkpoint requirement. Remove `activation_offloading` — the strided trainer "
+ "uses micro-batched forward passes for memory efficiency instead."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_strided_sequence_len(cls, data):
+ """Warn if sequence_len is too large for single-GPU strided EBFT."""
+ if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided":
+ return data
+ ebft = data.get("ebft", {})
+ seq_len = data.get("sequence_len", 512)
+ n_samples = ebft.get("n_samples_per_prompt", 4)
+ gen_len = ebft.get("generate_max_len", 8)
+ stride = ebft.get("stride", 8)
+ ctx_len = ebft.get("context_length", 8)
+ max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
+ full_seq = seq_len + max_blocks * gen_len
+ # Rough estimate: 8.7 GB per sample at S=3900 for 1B model
+ if full_seq * n_samples > 20000:
+ LOG.warning(
+ f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = "
+ f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM "
+ f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_strided_dataset_split(cls, data):
+ """Warn about the common `train_on_split` mistake (silently ignored by schema)."""
+ datasets = data.get("datasets", [])
+ for ds in datasets or []:
+ if isinstance(ds, dict) and ds.get("train_on_split"):
+ LOG.warning(
+ f"Dataset has `train_on_split: {ds['train_on_split']}` — this field "
+ f"is not recognized and will be silently ignored. "
+ f"Use `split: {ds['train_on_split']}` instead."
+ )
+ return data
+
+
class GRPOVllmValidationMixin:
"""Validation mixin for vllm when using GRPO."""
@@ -1507,6 +1625,7 @@ class ValidationMixin(
PretrainingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
+ EBFTValidationMixin,
GRPOVllmValidationMixin,
):
"""Full validation mixin for Axolotl configuration."""
diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py
index c0aa48d66..5198d4173 100644
--- a/src/axolotl/utils/schemas/vllm.py
+++ b/src/axolotl/utils/schemas/vllm.py
@@ -57,6 +57,13 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Reasoning parser for VLLM"},
)
+ enforce_eager: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Disable CUDA graph capture in vLLM. Required for models with "
+ "causal_conv1d (e.g., Qwen3.5 hybrid linear attention)."
+ },
+ )
serve_module: str | None = Field(
default=None,
json_schema_extra={
diff --git a/src/axolotl/utils/weight_serde.py b/src/axolotl/utils/weight_serde.py
new file mode 100644
index 000000000..d4b804681
--- /dev/null
+++ b/src/axolotl/utils/weight_serde.py
@@ -0,0 +1,94 @@
+"""Serialize / deserialize tensors for HTTP and IPC weight sync.
+
+NumPy has no bfloat16 dtype, so bf16 tensors are cast to fp16 on the wire and
+cast back at the destination (values beyond fp16's range, |x| > 65504, overflow).
+All encode/decode helpers live here so the logic isn't duplicated across
+trl_vllm.py, vllm_serve_lora.py, and vllm_worker_ext.py.
+"""
+
+import base64
+
+import torch
+
+
+def encode_for_http(name: str, weight: torch.Tensor) -> dict:
+ """Encode a named parameter for JSON transport over HTTP.
+
+ Returns a dict with keys: name, dtype (original), shape, data (base64).
+ bf16 tensors are sent as fp16 bytes; the original dtype is preserved in
+ the ``dtype`` field so the receiver can cast back.
+ """
+ w_cpu = weight.detach().contiguous().cpu()
+ orig_dtype = str(weight.dtype)
+ if w_cpu.dtype == torch.bfloat16:
+ w_cpu = w_cpu.half()
+ raw = w_cpu.numpy().tobytes()
+ return {
+ "name": name,
+ "dtype": orig_dtype,
+ "shape": list(weight.shape),
+ "data": base64.b64encode(raw).decode("ascii"),
+ }
+
+
+def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]:
+ """Decode an HTTP-encoded weight entry back to a named tensor.
+
+ Infers the wire dtype from bytes per element (2 -> fp16, 4 -> fp32, otherwise
+ the target dtype; bf16 arrives as fp16) and casts to ``entry["dtype"]``.
+ """
+ target_dtype = getattr(torch, entry["dtype"].split(".")[-1])
+ shape = tuple(entry["shape"])
+ raw = base64.b64decode(entry["data"])
+
+ n_elements = 1
+ for s in shape:
+ n_elements *= s
+ wire_bytes_per_elem = len(raw) // max(n_elements, 1)
+ if wire_bytes_per_elem == 2:
+ wire_dtype = torch.float16
+ elif wire_bytes_per_elem == 4:
+ wire_dtype = torch.float32
+ else:
+ wire_dtype = target_dtype
+
+ weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape)
+ if wire_dtype != target_dtype:
+ weight = weight.to(target_dtype)
+ return entry["name"], weight
+
+
+def encode_for_ipc(name: str, weight: torch.Tensor) -> dict:
+ """Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64).
+
+ Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype
+ (original), shape. bf16 tensors are serialized as fp16.
+ """
+ w = weight.contiguous()
+ target_dtype = str(w.dtype).split(".")[-1]
+ if w.dtype == torch.bfloat16:
+ w = w.half()
+ wire_dtype = str(w.dtype).split(".")[-1]
+ return {
+ "name": name,
+ "data": w.numpy().tobytes(),
+ "dtype": wire_dtype,
+ "target_dtype": target_dtype,
+ "shape": list(weight.shape),
+ }
+
+
+def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]:
+ """Decode an IPC-encoded weight entry back to a named tensor.
+
+ Handles optional ``target_dtype`` for backward compatibility with older
+ serve code that may not include it.
+ """
+ wire_dtype = getattr(torch, entry["dtype"])
+ weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape(
+ entry["shape"]
+ )
+ target_dtype = entry.get("target_dtype")
+ if target_dtype and target_dtype != entry["dtype"]:
+ weight = weight.to(getattr(torch, target_dtype))
+ return entry["name"], weight
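Reviewer sketch (not part of the patch): the HTTP wire format described in `encode_for_http` / `decode_from_http` can be illustrated without torch. The snippet below is a minimal sketch that uses only the standard library. The helper names `encode_fp16_entry` and `decode_entry` are hypothetical, and the explicit little-endian byte order is an assumption (the patch relies on NumPy's native order). It mirrors two ideas from the module: the original dtype travels in the `dtype` field while the payload is fp16 bytes, and the receiver infers the wire width from the byte count divided by the element count.

```python
import base64
import json
import struct


def encode_fp16_entry(name, values, shape, dtype="torch.bfloat16"):
    """Pack floats as little-endian fp16 bytes and base64-encode them for JSON."""
    raw = struct.pack(f"<{len(values)}e", *values)
    return {
        "name": name,
        "dtype": dtype,  # original dtype, kept so the receiver can cast back
        "shape": list(shape),
        "data": base64.b64encode(raw).decode("ascii"),
    }


def decode_entry(entry):
    """Recover the values, inferring the wire width from the byte count."""
    raw = base64.b64decode(entry["data"])
    n_elements = 1
    for dim in entry["shape"]:
        n_elements *= dim
    bytes_per_elem = len(raw) // max(n_elements, 1)
    # This sketch only handles fp16 (2 bytes) and fp32 (4 bytes) payloads.
    fmt = {2: "e", 4: "f"}[bytes_per_elem]
    return entry["name"], list(struct.unpack(f"<{n_elements}{fmt}", raw))


entry = encode_fp16_entry("layer.weight", [0.5, -1.25, 2.0, 3.0], shape=(2, 2))
payload = json.dumps(entry)  # the JSON text that would travel over HTTP
name, values = decode_entry(json.loads(payload))
```

The example values are exactly representable in fp16, so they come back unchanged; the 2x2 tensor occupies 8 bytes on the wire even though the advertised dtype is bfloat16.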
diff --git a/tests/test_ebft_kernels.py b/tests/test_ebft_kernels.py
new file mode 100644
index 000000000..2c05a887a
--- /dev/null
+++ b/tests/test_ebft_kernels.py
@@ -0,0 +1,294 @@
+"""Correctness tests for fused EBFT Triton kernels."""
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from axolotl.core.trainers.ebft.kernels import (
+ fused_cosine_similarity,
+ fused_diversity_penalty,
+ fused_log_softmax_gather,
+ fused_reinforce_loss,
+)
+
+ # Skip all tests if CUDA is not available
+pytestmark = pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
+)
+
+DEVICE = "cuda"
+
+
+# ---------------------------------------------------------------------------
+# 1. fused_log_softmax_gather
+# ---------------------------------------------------------------------------
+class TestFusedLogSoftmaxGather:
+ def _reference(self, logits, labels):
+ """PyTorch reference: log_softmax + gather."""
+ lp = F.log_softmax(logits.float(), dim=-1)
+ return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+
+ def test_basic_correctness(self):
+ B, S, V = 2, 16, 1024
+ logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
+ labels = torch.randint(0, V, (B, S), device=DEVICE)
+
+ ref = self._reference(logits, labels)
+ out = fused_log_softmax_gather(logits, labels)
+
+ torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
+
+ def test_large_vocab(self):
+ """Test with realistic vocab size (128K)."""
+ B, S, V = 1, 8, 128256
+ logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
+ labels = torch.randint(0, V, (B, S), device=DEVICE)
+
+ ref = self._reference(logits, labels)
+ out = fused_log_softmax_gather(logits, labels)
+
+ torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
+
+ def test_fp32_input(self):
+ B, S, V = 2, 8, 512
+ logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
+ labels = torch.randint(0, V, (B, S), device=DEVICE)
+
+ ref = self._reference(logits, labels)
+ out = fused_log_softmax_gather(logits, labels)
+
+ torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
+
+ def test_output_is_negative(self):
+ """Gathered log-probabilities should be <= 0 (up to a 1e-5 tolerance)."""
+ B, S, V = 4, 32, 2048
+ logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
+ labels = torch.randint(0, V, (B, S), device=DEVICE)
+
+ out = fused_log_softmax_gather(logits, labels)
+ assert (out <= 1e-5).all(), "log_softmax values should be <= 0"
+
+ def test_extreme_logits(self):
+ """Test numerical stability with very large/small logits."""
+ B, S, V = 1, 4, 256
+ logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
+ logits[:, 0, :] = 1000.0 # very large
+ logits[:, 1, :] = -1000.0 # very small
+ logits[:, 2, 0] = 1000.0 # one hot-ish
+ labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
+
+ ref = self._reference(logits, labels)
+ out = fused_log_softmax_gather(logits, labels)
+
+ assert torch.isfinite(out).all(), "Should handle extreme values"
+ torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
+
+ def test_2d_input(self):
+ """Test with pre-flattened (N, V) input."""
+ N, V = 64, 4096
+ logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
+ labels = torch.randint(0, V, (N,), device=DEVICE)
+
+ ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
+ out = fused_log_softmax_gather(logits, labels)
+
+ torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
+
+
+# ---------------------------------------------------------------------------
+# 2. fused_reinforce_loss
+# ---------------------------------------------------------------------------
+class TestFusedReinforceLoss:
+ def _reference(self, logps, advantages, mask):
+ """PyTorch reference implementation."""
+ loss_per_token = -logps * advantages
+ return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)
+
+ def test_basic_correctness(self):
+ N = 1024
+ logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
+
+ ref = self._reference(logps, advs, mask)
+ out = fused_reinforce_loss(logps, advs, mask)
+
+ torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
+
+ def test_2d_input(self):
+ """Test with (B, S) shaped inputs."""
+ B, S = 4, 256
+ logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
+ advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
+ mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
+
+ ref = self._reference(logps, advs, mask)
+ out = fused_reinforce_loss(logps, advs, mask)
+
+ torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
+
+ def test_all_masked(self):
+ """All-zero mask should return 0."""
+ N = 512
+ logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
+
+ out = fused_reinforce_loss(logps, advs, mask)
+ assert out.item() == 0.0
+
+ def test_all_unmasked(self):
+ N = 512
+ logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
+
+ ref = self._reference(logps, advs, mask)
+ out = fused_reinforce_loss(logps, advs, mask)
+
+ torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
+
+ def test_large(self):
+ """Test with realistic size (4 * 3000 tokens)."""
+ N = 12000
+ logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
+ mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
+
+ ref = self._reference(logps, advs, mask)
+ out = fused_reinforce_loss(logps, advs, mask)
+
+ torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
+
+
+# ---------------------------------------------------------------------------
+# 3. fused_cosine_similarity
+# ---------------------------------------------------------------------------
+class TestFusedCosineSimilarity:
+ def test_basic_correctness(self):
+ N, D = 64, 256
+ a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
+ b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
+
+ ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
+ out = fused_cosine_similarity(a, b)
+
+ torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
+
+ def test_batched(self):
+ """Test with (B, N, NB, D) shaped input."""
+ B, N, NB, D = 2, 4, 16, 512
+ a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
+ b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
+
+ ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
+ out = fused_cosine_similarity(a, b)
+
+ torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
+
+ def test_identical_vectors(self):
+ """Identical vectors should give similarity = 1."""
+ N, D = 16, 128
+ a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
+
+ out = fused_cosine_similarity(a, a)
+ torch.testing.assert_close(
+ out,
+ torch.ones(N, device=DEVICE, dtype=torch.float32),
+ atol=1e-5,
+ rtol=1e-5,
+ )
+
+ def test_orthogonal_vectors(self):
+ """Orthogonal vectors should give similarity = 0."""
+ D = 128
+ a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
+ b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
+ a[0, 0] = 1.0
+ b[0, 1] = 1.0
+
+ out = fused_cosine_similarity(a, b)
+ assert abs(out.item()) < 1e-5
+
+ def test_opposite_vectors(self):
+ """Opposite vectors should give similarity = -1."""
+ N, D = 8, 64
+ a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
+ out = fused_cosine_similarity(a, -a)
+ torch.testing.assert_close(
+ out,
+ -torch.ones(N, device=DEVICE, dtype=torch.float32),
+ atol=1e-5,
+ rtol=1e-5,
+ )
+
+ def test_large_dimension(self):
+ """Test with large feature dimension (multi-layer concatenated features)."""
+ N, D = 32, 4608 # 3 layers * 1536 hidden
+ a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
+ b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
+
+ ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
+ out = fused_cosine_similarity(a, b)
+
+ torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
+
+
+# ---------------------------------------------------------------------------
+# 4. fused_diversity_penalty
+# ---------------------------------------------------------------------------
+class TestFusedDiversityPenalty:
+ def _reference(self, embeddings):
+ """PyTorch reference: bmm + mask diagonal + mean."""
+ B, N, D = embeddings.shape
+ sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
+ eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
+ sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
+ return sims.sum(dim=-1) / (N - 1)
+
+ def test_basic_correctness(self):
+ B, N, D = 4, 4, 256
+ emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
+
+ ref = self._reference(emb)
+ out = fused_diversity_penalty(emb)
+
+ torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
+
+ def test_two_samples(self):
+ """With n=2, diversity = dot(a, b) for each."""
+ B, D = 3, 128
+ emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
+
+ ref = self._reference(emb)
+ out = fused_diversity_penalty(emb)
+
+ torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
+
+ def test_identical_samples(self):
+ """Identical samples should give the maximum penalty (least diversity)."""
+ B, N, D = 2, 4, 64
+ vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
+ emb = vec.expand(B, N, D).contiguous()
+
+ out = fused_diversity_penalty(emb)
+ # Should be ||vec||^2 for each (self-excluded mean of identical dot products)
+ expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
+ torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)
+
+ def test_large(self):
+ """Test with realistic EBFT dimensions."""
+ B, N, D = 1, 4, 4608 # multi-layer features
+ emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
+
+ ref = self._reference(emb)
+ out = fused_diversity_penalty(emb)
+
+ torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)
+
+ def test_single_sample_returns_zeros(self):
+ """N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
+ B, D = 3, 128
+ emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
+ out = fused_diversity_penalty(emb)
+ assert (out == 0).all(), "N=1 diversity should be exactly zero"
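Reviewer sketch (not part of the patch): the quantity checked by `TestFusedDiversityPenalty` is, for each sample, the mean dot product with the other samples in its group. The pure-Python version below is a minimal sketch for a single group (no batch dimension); the function name `diversity_penalty` is hypothetical and the code only mirrors the `_reference` helper in the tests, not the Triton kernel itself.

```python
def diversity_penalty(samples):
    """For each sample, mean dot product with the other samples (self excluded)."""
    n = len(samples)
    if n < 2:
        return [0.0] * n  # no pairs to compare, as in the N=1 test
    out = []
    for i, a in enumerate(samples):
        total = sum(
            sum(x * y for x, y in zip(a, b))
            for j, b in enumerate(samples)
            if j != i
        )
        out.append(total / (n - 1))
    return out


vec = [1.0, 2.0, 2.0]  # squared norm = 1 + 4 + 4 = 9
identical = diversity_penalty([vec] * 4)
orthogonal_pair = diversity_penalty([[1.0, 0.0], [0.0, 1.0]])
```

With four copies of the same vector, every entry equals the squared norm (9.0), which is the largest value the penalty takes for vectors of that norm; an orthogonal pair gives 0.0 for both entries.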
diff --git a/tests/test_ebft_strided_structured.py b/tests/test_ebft_strided_structured.py
new file mode 100644
index 000000000..e4ad946a0
--- /dev/null
+++ b/tests/test_ebft_strided_structured.py
@@ -0,0 +1,363 @@
+"""Tests for the EBFT strided structured dataset transform and data loading."""
+
+import pytest
+from datasets import Dataset
+from tokenizers import Tokenizer, models, pre_tokenizers
+from transformers import PreTrainedTokenizerFast
+
+from axolotl.prompt_strategies.ebft import load as load_ebft
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture
+def tokenizer():
+ """Create a simple word-level tokenizer — no network access needed."""
+ # Build a tiny vocab covering common test words
+ vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
+ words = (
+ "what is 2 + the answer 4 hello world goodbye bye hi short prompt "
+ "x write code print test some string metadata noise ok good python "
+ "sampling abc 123 def solve return this that"
+ ).split()
+ for w in words:
+ if w not in vocab:
+ vocab[w] = len(vocab)
+
+ backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
+ backend.pre_tokenizer = pre_tokenizers.Whitespace()
+
+ tok = PreTrainedTokenizerFast(
+ tokenizer_object=backend,
+ bos_token="[BOS]",
+ eos_token="[EOS]",
+ pad_token="[PAD]",
+ unk_token="[UNK]",
+ )
+ return tok
+
+
+@pytest.fixture
+def cfg():
+ return DictDefault({"sequence_len": 64})
+
+
+@pytest.fixture
+def transform_fn_and_kwargs(cfg):
+ result = load_ebft("ebft_strided_structured.transform", cfg)
+ assert result is not None, "Failed to load ebft_strided_structured transform"
+ transform_fn, map_kwargs = result
+ return transform_fn, map_kwargs
+
+
+class TestEBFTStridedStructuredTransform:
+ """Tests for the dataset transform function itself."""
+
+ def test_transform_loads(self, transform_fn_and_kwargs):
+ transform_fn, map_kwargs = transform_fn_and_kwargs
+ assert callable(transform_fn)
+ assert "remove_columns" in map_kwargs
+
+ def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
+ """Transform should signal removal of all original columns."""
+ _, map_kwargs = transform_fn_and_kwargs
+ assert map_kwargs["remove_columns"] == "__all__"
+
+ def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
+ """Prompt tokens get labels=-100, completion tokens get real labels."""
+ transform_fn, _ = transform_fn_and_kwargs
+ example = {"input": "what is 2 + 2", "output": "the answer is 4"}
+ result = transform_fn(example, tokenizer=tokenizer)
+
+ assert "input_ids" in result
+ assert "labels" in result
+ assert "attention_mask" in result
+ assert "prompt_length" in result
+
+ prompt_length = result["prompt_length"]
+ labels = result["labels"]
+ seq_len = len(result["input_ids"])
+
+ assert seq_len == 64, "Should be padded to sequence_len"
+ assert len(labels) == seq_len
+ assert prompt_length > 0
+
+ # Prompt tokens should be masked
+ for i in range(prompt_length):
+ assert labels[i] == -100, f"Prompt token at {i} should be -100"
+
+ # At least one completion token should have a real label
+ completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
+ assert len(completion_labels) > 0, "Should have non-masked completion tokens"
+
+ def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
+ """prompt_length should be the exact boundary between -100 and real labels."""
+ transform_fn, _ = transform_fn_and_kwargs
+ example = {"input": "hello world", "output": "goodbye world"}
+ result = transform_fn(example, tokenizer=tokenizer)
+
+ prompt_length = result["prompt_length"]
+ labels = result["labels"]
+
+ assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
+ assert labels[prompt_length] != -100, (
+ "First completion token should not be masked"
+ )
+
+ def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
+ """Padding tokens should have labels=-100 and attention_mask=0."""
+ transform_fn, _ = transform_fn_and_kwargs
+ example = {"input": "hi", "output": "bye"}
+ result = transform_fn(example, tokenizer=tokenizer)
+
+ attention_mask = result["attention_mask"]
+ labels = result["labels"]
+
+ real_len = sum(attention_mask)
+ assert real_len < 64, "Short example should have padding"
+
+ for i in range(real_len, 64):
+ assert attention_mask[i] == 0, (
+ f"Pad position {i} should have attention_mask=0"
+ )
+ assert labels[i] == -100, f"Pad position {i} should have labels=-100"
+
+ def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
+ """Long completions should be truncated to fit sequence_len."""
+ transform_fn, _ = transform_fn_and_kwargs
+ example = {
+ "input": "short prompt",
+ "output": "x " * 500,
+ }
+ result = transform_fn(example, tokenizer=tokenizer)
+ assert len(result["input_ids"]) == 64
+
+ def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
+ """Transform should handle different field name conventions."""
+ transform_fn, _ = transform_fn_and_kwargs
+
+ result = transform_fn(
+ {"prompt": "what", "completion": "this"}, tokenizer=tokenizer
+ )
+ assert result["prompt_length"] > 0
+
+ result = transform_fn(
+ {"question": "what", "answer": "this"}, tokenizer=tokenizer
+ )
+ assert result["prompt_length"] > 0
+
+ def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
+ """Without tokenizer, should return a prompt string."""
+ transform_fn, _ = transform_fn_and_kwargs
+ result = transform_fn({"input": "hello", "output": "world"})
+ assert "prompt" in result
+ assert result["prompt"] == "hello"
+
+
+class TestEBFTColumnRemoval:
+ """Tests for the __all__ column removal logic in the RL data path."""
+
+ def _filter_remove_columns(self, map_kwargs, dataset):
+ """Reproduce the filtering logic from rl.py _load_split."""
+ if "remove_columns" in map_kwargs:
+ ds_columns = dataset.column_names
+ if map_kwargs["remove_columns"] == "__all__":
+ map_kwargs["remove_columns"] = list(ds_columns)
+ else:
+ map_kwargs["remove_columns"] = [
+ c for c in map_kwargs["remove_columns"] if c in ds_columns
+ ]
+ return map_kwargs
+
+ def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
+ """After mapping, only tokenized columns should remain."""
+ transform_fn, map_kwargs = transform_fn_and_kwargs
+ map_kwargs = dict(map_kwargs) # copy
+
+ ds = Dataset.from_list(
+ [
+ {"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
+ ]
+ )
+
+ map_kwargs = self._filter_remove_columns(map_kwargs, ds)
+ assert "input" in map_kwargs["remove_columns"]
+ assert "output" in map_kwargs["remove_columns"]
+ assert "extra_field" in map_kwargs["remove_columns"]
+
+ from functools import partial
+
+ mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
+ assert "input_ids" in mapped.column_names
+ assert "labels" in mapped.column_names
+ assert "prompt_length" in mapped.column_names
+ assert "input" not in mapped.column_names
+ assert "output" not in mapped.column_names
+ assert "extra_field" not in mapped.column_names
+
+ def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
+ """Datasets with many extra metadata columns should all be cleaned up."""
+ transform_fn, map_kwargs = transform_fn_and_kwargs
+ map_kwargs = dict(map_kwargs)
+
+ ds = Dataset.from_list(
+ [
+ {
+ "input": "write hello world",
+ "output": "print hello",
+ "id": "abc 123",
+ "domain": "python",
+ "generation_algorithm": "sampling",
+ "llm_judgement": "good",
+ "unit_tests": "test",
+ "tests_execution_status": "ok",
+ "average_test_score": 0.95,
+ },
+ ]
+ )
+
+ map_kwargs = self._filter_remove_columns(map_kwargs, ds)
+ assert len(map_kwargs["remove_columns"]) == 9
+
+ from functools import partial
+
+ mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
+
+ expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
+ assert set(mapped.column_names) == expected_columns
+
+ def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
+ """No string-typed columns should remain (would crash the DataLoader)."""
+ transform_fn, map_kwargs = transform_fn_and_kwargs
+ map_kwargs = dict(map_kwargs)
+
+ ds = Dataset.from_list(
+ [
+ {"input": "test", "output": "test", "notes": "some string metadata"},
+ ]
+ )
+
+ map_kwargs = self._filter_remove_columns(map_kwargs, ds)
+
+ from functools import partial
+
+ mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
+
+ for col in mapped.column_names:
+ feat = mapped.features[col]
+ assert str(feat) != "string", (
+ f"Column '{col}' is still a string — would crash DataLoader"
+ )
+
+ def test_filter_preserves_explicit_list(self):
+ """When remove_columns is an explicit list, only existing columns are kept."""
+ ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
+ map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
+
+ ds_columns = ds.column_names
+ map_kwargs["remove_columns"] = [
+ c for c in map_kwargs["remove_columns"] if c in ds_columns
+ ]
+
+ assert map_kwargs["remove_columns"] == ["a", "b"]
+ assert "missing_col" not in map_kwargs["remove_columns"]
+
+
+class TestMultiTurnSeparators:
+ """Verify multi-turn transforms and trainer-side GT reconstruction."""
+
+ def test_multiturn_transform_splits_turns(self):
+ """Transform should store the first assistant turn as GT and keep later turns separately."""
+ from axolotl.prompt_strategies.ebft import load as load_ebft
+ from axolotl.utils.dict import DictDefault
+
+ cfg = DictDefault({"sequence_len": 512})
+ fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
+ out = fn(
+ {
+ "messages": [
+ {"role": "user", "content": "Q1"},
+ {"role": "assistant", "content": "A1"},
+ {"role": "user", "content": "Q2"},
+ {"role": "assistant", "content": "A2"},
+ ]
+ }
+ )
+ # ground_truth is only the first assistant turn
+ assert out["ground_truth"] == "A1"
+ # remaining_turns carries the rest for trainer-side reconstruction
+ assert out["remaining_turns"] == [
+ {"role": "user", "content": "Q2"},
+ {"role": "assistant", "content": "A2"},
+ ]
+
+ def test_multiturn_gt_reconstruction_via_chat_template(self):
+ """Trainer-side GT reconstruction should insert role markers between turns.
+
+ This tests the logic from trainer.py:284-299 that reconstructs multi-turn
+ GT using apply_chat_template, ensuring assistant turns are separated by
+ role markers rather than concatenated as raw text.
+ """
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
+ )
+
+ # Simulate the transform output
+ prompt_msgs = [{"role": "user", "content": "Q1"}]
+ gt = "A1"
+ remaining_turns = [
+ {"role": "user", "content": "Q2"},
+ {"role": "assistant", "content": "A2"},
+ ]
+
+ # --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
+ prompt_text = tokenizer.apply_chat_template(
+ prompt_msgs, tokenize=False, add_generation_prompt=True
+ )
+ gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
+ gt_conv.extend(remaining_turns)
+ full_gt_text = tokenizer.apply_chat_template(
+ gt_conv, tokenize=False, add_generation_prompt=False
+ )
+
+ # The full GT text should contain both assistant turns with role markers
+ assert "A1" in full_gt_text
+ assert "A2" in full_gt_text
+ # Raw concatenation "A1A2" should NOT appear — role markers separate them
+ assert "A1A2" not in full_gt_text, (
+ "GT reconstruction should have role markers between turns, not raw concatenation"
+ )
+ # The user turn Q2 should appear between A1 and A2
+ a1_pos = full_gt_text.index("A1")
+ a2_pos = full_gt_text.index("A2")
+ q2_pos = full_gt_text.index("Q2")
+ assert a1_pos < q2_pos < a2_pos, (
+ "Turn order should be A1 -> Q2 -> A2 in rendered GT"
+ )
+ # The GT should start with the prompt
+ assert full_gt_text.startswith(prompt_text), (
+ "Full GT should start with the rendered prompt"
+ )
+
+ def test_multiturn_gt_reconstruction_fallback_single_turn(self):
+ """Single-turn prompts in a multi-turn dataset should use raw concatenation."""
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
+ )
+
+ prompt_msgs = [{"role": "user", "content": "Q1"}]
+ gt = "A1"
+ # remaining_turns would be [] for single-turn prompts
+
+ prompt_text = tokenizer.apply_chat_template(
+ prompt_msgs, tokenize=False, add_generation_prompt=True
+ )
+
+ # With empty remaining_turns, trainer falls through to raw concat
+ # (trainer.py:302: gt_texts.append(prompt_text + gt))
+ gt_text = prompt_text + gt
+ assert gt_text.endswith("A1")
+ assert prompt_text in gt_text
diff --git a/tests/test_http_weight_sync.py b/tests/test_http_weight_sync.py
new file mode 100644
index 000000000..97c5ece71
--- /dev/null
+++ b/tests/test_http_weight_sync.py
@@ -0,0 +1,158 @@
+"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
+
+Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
+the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
+"""
+
+import pytest
+import torch
+
+from axolotl.utils.weight_serde import (
+ decode_from_http,
+ decode_from_ipc,
+ encode_for_http,
+ encode_for_ipc,
+)
+
+# ---------------------------------------------------------------------------
+# Stage 1: trainer → serve endpoint (HTTP with base64)
+# ---------------------------------------------------------------------------
+
+
+class TestHttpEncodeRoundTrip:
+ """Test encode_for_http / decode_from_http."""
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+ def test_round_trip_dtype(self, dtype):
+ original = torch.randn(32, 64, dtype=dtype)
+ entry = encode_for_http("layer.weight", original)
+ name, decoded = decode_from_http(entry)
+
+ assert name == "layer.weight"
+ assert decoded.dtype == dtype
+ assert decoded.shape == original.shape
+ if dtype == torch.bfloat16:
+ # bf16→fp16→bf16 is exact in fp16's normal range; tolerance is slack
+ torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
+ else:
+ torch.testing.assert_close(decoded, original, atol=0, rtol=0)
+
+ def test_bfloat16_wire_format_is_fp16(self):
+ """bf16 tensors should be sent as fp16 on the wire."""
+ import base64
+
+ original = torch.randn(8, 16, dtype=torch.bfloat16)
+ entry = encode_for_http("w", original)
+ raw = base64.b64decode(entry["data"])
+ # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
+ assert len(raw) == 8 * 16 * 2
+ # dtype field should preserve original dtype for reconstruction
+ assert entry["dtype"] == "torch.bfloat16"
+
+ def test_multidimensional_shapes(self):
+ for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
+ original = torch.randn(*shape, dtype=torch.bfloat16)
+ entry = encode_for_http("w", original)
+ _, decoded = decode_from_http(entry)
+ assert decoded.shape == original.shape
+ assert decoded.dtype == torch.bfloat16
+
+
+# ---------------------------------------------------------------------------
+# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
+# ---------------------------------------------------------------------------
+
+
+class TestIpcEncodeRoundTrip:
+ """Test encode_for_ipc / decode_from_ipc."""
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+ def test_round_trip_dtype(self, dtype):
+ original = torch.randn(32, 64, dtype=dtype)
+ entry = encode_for_ipc("layer.weight", original)
+ name, decoded = decode_from_ipc(entry)
+
+ assert name == "layer.weight"
+ assert decoded.dtype == dtype
+ assert decoded.shape == original.shape
+ if dtype == torch.bfloat16:
+ torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
+ else:
+ torch.testing.assert_close(decoded, original, atol=0, rtol=0)
+
+ def test_bfloat16_ipc_wire_is_fp16(self):
+ """bf16 tensors should be serialized as fp16 bytes in IPC."""
+ original = torch.randn(4, 8, dtype=torch.bfloat16)
+ entry = encode_for_ipc("w", original)
+ assert entry["dtype"] == "float16"
+ assert entry["target_dtype"] == "bfloat16"
+ assert len(entry["data"]) == 4 * 8 * 2 # fp16 bytes
+
+ def test_fp32_has_no_target_dtype_mismatch(self):
+ original = torch.randn(4, 8, dtype=torch.float32)
+ entry = encode_for_ipc("w", original)
+ assert entry["dtype"] == "float32"
+ assert entry["target_dtype"] == "float32"
+
+ def test_worker_handles_missing_target_dtype(self):
+ """Backward compat: older serve code may not send target_dtype."""
+ entry = {
+ "name": "w",
+ "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
+ "dtype": "float32",
+ "shape": [4, 8],
+ # no target_dtype key
+ }
+ name, decoded = decode_from_ipc(entry)
+ assert decoded.dtype == torch.float32
+ assert decoded.shape == (4, 8)
+
+
+# ---------------------------------------------------------------------------
+# Full pipeline: trainer → serve → worker
+# ---------------------------------------------------------------------------
+
+
+class TestFullPipelineRoundTrip:
+ """End-to-end: trainer → serve → worker."""
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+ def test_three_stage_round_trip(self, dtype):
+ """Tensor survives trainer→serve→worker with correct dtype and values."""
+ original = torch.randn(16, 32, dtype=dtype)
+
+ # Stage 1: trainer encodes for HTTP
+ http_entry = encode_for_http("model.layers.0.weight", original)
+
+ # Stage 2: serve decodes HTTP, re-encodes for IPC
+ name, at_serve = decode_from_http(http_entry)
+ ipc_entry = encode_for_ipc(name, at_serve)
+
+ # Stage 3: worker decodes IPC
+ _, at_worker = decode_from_ipc(ipc_entry)
+
+ assert at_worker.dtype == dtype
+ assert at_worker.shape == original.shape
+ if dtype == torch.bfloat16:
+ # Two bf16→fp16→bf16 hops; the second sees values fp16 already represents
+ torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
+ else:
+ torch.testing.assert_close(at_worker, original, atol=0, rtol=0)
+
+ def test_bfloat16_precision_loss_is_bounded(self):
+ """bf16→fp16→bf16 round-trip error should be small."""
+ original = torch.randn(256, 256, dtype=torch.bfloat16)
+ http_entry = encode_for_http("w", original)
+ _, at_serve = decode_from_http(http_entry)
+ ipc_entry = encode_for_ipc("w", at_serve)
+ _, at_worker = decode_from_ipc(ipc_entry)
+
+ max_err = (at_worker.float() - original.float()).abs().max().item()
+ # Relative eps: bf16 ≈ 7.8e-3, fp16 ≈ 9.8e-4; 0.05 is a loose absolute bound
+ assert max_err < 0.05, f"Max error {max_err} exceeds bound"
+
+ def test_bfloat16_numpy_would_crash_without_fix(self):
+ """Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
+ t = torch.randn(4, 4, dtype=torch.bfloat16)
+ with pytest.raises((RuntimeError, TypeError)):
+ t.numpy()
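Reviewer sketch (not part of the patch): the tolerances in the bf16 round-trip tests can be reasoned about without torch. bf16 keeps 7 explicit mantissa bits and fp16 keeps 10, so a bf16 value whose magnitude lies in fp16's normal range (2^-14 up to 65504) converts to fp16 and back unchanged; only smaller values are re-rounded onto fp16's subnormal grid. The snippet below is a minimal sketch using the standard library. The helpers `to_bf16` and `via_fp16` are hypothetical, they emulate the two formats on Python floats, and they handle finite values only. Values above 65504 are outside this sketch.

```python
import struct


def to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value (as a float)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the 16 low bits that bf16 drops.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def via_fp16(x):
    """Round a Python float to the nearest float16 value (as a float)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# bf16 values inside fp16's normal range survive the fp16 hop exactly.
samples = [to_bf16(v) for v in (0.1, -3.14159, 1.0, 123.456, 0.001)]
exact_in_range = all(via_fp16(x) == x for x in samples)

# A bf16 value below 2**-14 is re-rounded onto fp16's subnormal grid.
tiny = to_bf16(1e-6)
tiny_once = to_bf16(via_fp16(tiny))
tiny_twice = to_bf16(via_fp16(tiny_once))
```

In this sketch the first hop changes `tiny` by less than 2^-25 in absolute terms and the second hop leaves it unchanged, which is why loose absolute tolerances such as 1e-2 pass comfortably for `torch.randn` inputs.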