EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixes
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address code review feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
@@ -22,6 +22,7 @@ RUN apt update && \
    chmod 700 ~/.ssh && \
    printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
    printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
    printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
    chmod +x /root/cloud-entrypoint.sh && \
    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
examples/ebft/README.md (Normal file, 214 lines)
@@ -0,0 +1,214 @@
# Energy-Based Fine-Tuning (EBFT)

EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl.

## Overview

EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

**Key advantages over SFT:**

- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously

**Key advantages over RLVR:**

- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)

## Two Modes

EBFT supports two modes depending on your data format:

### Structured Mode (`mode: structured`, default)

For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).

- Extends GRPOTrainer — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions

### Strided Mode (`mode: strided`)

For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).

- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)

### Common to both modes:

- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)
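As an illustration of the feature-extraction recipe (fractional layer depths, last-token pooling, per-layer L2 normalization before concatenation), here is a minimal NumPy sketch. The function name and the fraction-to-index rounding are assumptions for illustration, not the trainer's exact code:

```python
import numpy as np

def extract_features(hidden_states, feature_layers=(0.25, 0.5, 0.75)):
    """Pool last-token hidden states at fractional depths, L2-normalize
    each layer's embedding, and concatenate (hypothetical helper).

    hidden_states: list of arrays of shape (batch, seq, dim), one per layer,
    with index 0 being the embedding layer (HF-style output_hidden_states).
    """
    num_layers = len(hidden_states) - 1  # index 0 is the embedding layer
    feats = []
    for frac in feature_layers:
        layer_idx = round(frac * num_layers)
        emb = hidden_states[layer_idx][:, -1, :]        # last-token pooling
        norm = np.linalg.norm(emb, axis=-1, keepdims=True)
        feats.append(emb / np.maximum(norm, 1e-8))      # per-layer L2 norm
    return np.concatenate(feats, axis=-1)
```

Normalizing each layer independently keeps a single layer with large activations from dominating the concatenated embedding.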

## Quick Start

### Structured Mode (QA data + vLLM)

```bash
# 1. Start vLLM server
python -m trl.scripts.vllm_serve \
    --model meta-llama/Llama-3.2-1B \
    --host 0.0.0.0 --port 8000 \
    --gpu-memory-utilization 0.3

# 2. Train
axolotl train examples/ebft/llama-1b-ebft-opencode.yaml
```

### Strided Mode (unstructured text)

```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```

## Configuration

### Common EBFT Settings

```yaml
rl: ebft

ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]

  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token

  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true

  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0

  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
```

### Strided Mode Settings

```yaml
ebft:
  mode: strided
  stride: 8                  # tokens between anchor points (paper default: 8)
  context_length: 8          # context window per block (paper default: 8)
  generate_max_len: 8        # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4    # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0               # RL loss weight
  advantage_estimator: rloo  # rloo (recommended), group_norm, or reinforce
```

### Structured Mode Settings (via TRL)

```yaml
trl:
  num_generations: 4          # samples per prompt
  max_completion_length: 256  # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
```

### Dataset Format

**Structured mode** — QA data with prompt + ground-truth completion:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
```

Transform returns: `{"prompt": ..., "ground_truth": ...}`

**Strided mode** — raw text tokenized to fixed length:

```yaml
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
```

Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`

## How It Works

### Structured Mode

1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
4. **RLOO advantages**: subtract leave-one-out group mean
5. **Policy gradient**: clipped PPO-style loss
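The reward and advantage steps can be sketched in a few lines of NumPy. The helper names are hypothetical and this deliberately omits the clipping and reward scaling the trainer applies:

```python
import numpy as np

def ebft_rewards(gen_embs, gt_emb):
    """Feature-matching reward, 2 * alignment - 2 * diversity (paper eq 7, sketch).

    gen_embs: (G, D) whitened embeddings of the G sampled completions.
    gt_emb:   (D,) whitened embedding of the ground-truth completion.
    """
    n_gen = gen_embs.shape[0]
    gen_norm = gen_embs / np.linalg.norm(gen_embs, axis=-1, keepdims=True)
    gt_norm = gt_emb / np.linalg.norm(gt_emb)
    alignment = gen_norm @ gt_norm                       # cosine similarity per sample
    dots = gen_embs @ gen_embs.T                         # raw pairwise dot products
    off_diag = dots[~np.eye(n_gen, dtype=bool)].reshape(n_gen, n_gen - 1)
    diversity = off_diag.mean(axis=-1)                   # mean similarity to siblings
    return 2 * alignment - 2 * diversity

def rloo_advantages(rewards):
    """Leave-one-out baseline: subtract the mean reward of the other samples."""
    n_gen = rewards.shape[0]
    baseline = (rewards.sum() - rewards) / (n_gen - 1)
    return rewards - baseline
```

By construction, RLOO advantages within a group sum to zero, so at least two samples per prompt are required for a non-degenerate signal.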

### Strided Mode

1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document
2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
4. **Per-block rewards**:
   - **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
   - **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
   - **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
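As a concrete illustration of the anchor arithmetic and of what a strided mask encodes, here is a pure-Python sketch. The actual trainer builds this as a `flex_attention` block mask; the function names and the exact visibility rule shown here are assumptions:

```python
def anchor_positions(seq_len, stride=8, context_length=8, generate_max_len=8):
    """Anchor points per the formula above:
    num_blocks = (seq_len - gen_len - ctx_len) / stride + 1 (integer division)."""
    num_blocks = (seq_len - generate_max_len - context_length) // stride + 1
    return [context_length + i * stride for i in range(num_blocks)]

def strided_can_attend(q_pos, kv_pos, anchor, context_length):
    """Hypothetical visibility rule for one block: a token generated at q_pos
    sees its block's local context window plus earlier generated tokens in the
    same block, but nothing from other blocks."""
    in_context = anchor - context_length <= kv_pos < anchor
    in_block = anchor <= kv_pos <= q_pos
    return in_context or in_block
```

With the paper defaults and `sequence_len: 256`, this yields `(256 - 8 - 8) // 8 + 1 = 31` blocks per document, each generated in parallel under its own mask.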

### Tracked Metrics

| Metric | Description |
|--------|-------------|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | alignment - diversity (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
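The `ebft/cfm_loss` metric compares the mean generated feature against the ground-truth feature. A minimal sketch, with the expectation estimated by the sample mean (function name assumed):

```python
import numpy as np

def cfm_loss(gen_feats, gt_feat):
    """Conditional feature-matching loss (paper eq 2):
    ||E[phi(y_hat)] - phi(y)||^2, with E[.] estimated over the sampled rollouts."""
    return float(np.sum((gen_feats.mean(axis=0) - gt_feat) ** 2))
```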

## Tips and Recommendations

### Reward coefficients

- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching.
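To see why whitening puts the alignment and diversity terms on compatible scales, here is a minimal sketch of SVD whitening over a stack of row-wise features. This is an illustration under those assumptions, not the trainer's exact routine:

```python
import numpy as np

def svd_whiten(feats, eps=1e-6):
    """Decorrelate feature dimensions via SVD: after whitening, the empirical
    covariance is (approximately) the identity, so raw dot products between
    whitened vectors are bounded and comparable to cosine similarities."""
    centered = feats - feats.mean(axis=0, keepdims=True)
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    # Project onto the singular directions, rescale each to unit variance
    return centered @ vt.T / np.maximum(s, eps) * np.sqrt(len(feats))
```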

### Feature extraction

- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).

### Performance

- **`torch_compile: true`**: Recommended for strided mode. Provides additional speedup via graph compilation.
- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.

### Memory

- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.

## Example Configs

| Config | Mode | Model | Description |
|--------|------|-------|-------------|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |
| `llama-1b-ebft-strided-structured.yaml` | Strided | Llama-3.2-1B | Structured (prompt, completion) data, no vLLM |
| `llama-8b-ebft-strided-fft.yaml` | Strided | Llama-3.1-8B | Full-parameter unstructured code |
| `qwen35-4b-ebft-structured-async.yaml` | Structured | Qwen3.5-4B | Hybrid linear attention, async vLLM |

## Citation

```bibtex
@article{jelassi2026matching,
  title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
  author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
  journal={arXiv preprint arXiv:2603.12248},
  year={2026}
}
```

examples/ebft/ebft_opencode.py (Normal file, 28 lines)
@@ -0,0 +1,28 @@
```python
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT.

Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer.
"""


def transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "user", "content": example["input"]},
            ],
            "ground_truth": example["output"],
        }

    return transform_fn, {
        "remove_columns": [
            "id",
            "domain",
            "generation_algorithm",
            "llm_judgement",
            "unit_tests",
            "tests_execution_status",
            "average_test_score",
        ]
    }
```

examples/ebft/ebft_pretrain.py (Normal file, 31 lines)
@@ -0,0 +1,31 @@
```python
"""
Dataset transform for unstructured text data with strided EBFT.

Tokenizes raw text into fixed-length input_ids for the strided trainer.
Sequences are padded to sequence_len for uniform batching.
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        text = example.get("question", example.get("text", ""))
        if tokenizer is None:
            return {"prompt": text}

        encoded = tokenizer(
            text,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            add_special_tokens=True,
            return_tensors=None,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": list(encoded["input_ids"]),
        }

    return transform_fn, {"remove_columns": ["question", "answer"]}
```

examples/ebft/ebft_strided_structured.py (Normal file, 80 lines)
@@ -0,0 +1,80 @@
```python
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.

Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).

Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        # Extract prompt and completion from the example
        prompt_text = example.get(
            "input", example.get("prompt", example.get("question", ""))
        )
        completion_text = example.get(
            "output", example.get("completion", example.get("answer", ""))
        )

        if tokenizer is None:
            return {"prompt": prompt_text}

        # `pad_token_id or eos_token_id` would misbehave when pad_token_id is 0,
        # so check for None explicitly
        pad_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        # Tokenize prompt and completion separately
        prompt_enc = tokenizer(
            prompt_text,
            truncation=False,
            add_special_tokens=True,
            return_tensors=None,
        )
        completion_enc = tokenizer(
            completion_text,
            truncation=False,
            add_special_tokens=False,
            return_tensors=None,
        )

        prompt_ids = prompt_enc["input_ids"]
        completion_ids = completion_enc["input_ids"]

        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
        total_len = len(prompt_ids) + len(completion_ids)
        if total_len > seq_len:
            # Truncate completion first, then prompt if needed
            max_completion = seq_len - len(prompt_ids)
            if max_completion < 1:
                # Prompt alone exceeds seq_len — truncate prompt, keep at least
                # 1 completion token
                prompt_ids = prompt_ids[: seq_len - 1]
                completion_ids = completion_ids[:1]
            else:
                completion_ids = completion_ids[:max_completion]

        input_ids = prompt_ids + completion_ids
        prompt_length = len(prompt_ids)

        # Labels: -100 for prompt tokens, input_ids for completion tokens
        labels = [-100] * prompt_length + completion_ids

        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    # Signal to remove all original columns (filtered to existing ones at map time)
    return transform_fn, {
        "remove_columns": "__all__",
    }
```

examples/ebft/llama-1b-ebft-opencode-novllm.yaml (Normal file, 64 lines)
@@ -0,0 +1,64 @@
```yaml
# EBFT validation config — no vLLM, uses HF generate for simplicity
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 128
  temperature: 1.0
  use_vllm: false
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]

sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 10

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-validation

wandb_project: ebft
wandb_run_id:
wandb_watch:
wandb_log_model:

logging_steps: 1
save_steps: 100
```

examples/ebft/llama-1b-ebft-opencode.yaml (Normal file, 81 lines)
@@ -0,0 +1,81 @@
```yaml
# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
#
# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
# https://arxiv.org/abs/2603.12248
#
# Prerequisites:
# 1. Start vLLM server on a separate GPU:
#    CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
#      --model meta-llama/Llama-3.2-1B \
#      --host 0.0.0.0 --port 8000 \
#      --gpu-memory-utilization 0.4 --dtype bfloat16
#
# 2. Run training:
#    CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3

# --- Training method ---
rl: ebft

# --- EBFT configuration ---
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # extract hidden states at 25%, 50%, 75% depth
  embed_method: last_token           # pool to sequence embedding via last token
  use_whitening: false               # SVD whitening (disable for speed in small runs)
  alignment_coef: 1.0                # cosine similarity with ground-truth features
  diversity_coef: 1.0                # pairwise similarity penalty
  ce_coef: 0.0                       # cross-entropy on ground-truth (0 = pure feature matching)

# --- Generation settings (via TRL/GRPO infrastructure) ---
trl:
  num_generations: 4          # samples per prompt for RLOO
  max_completion_length: 256  # max generated tokens
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

# --- Dataset ---
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]  # first 1% for validation runs

# --- Training hyperparameters ---
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 50

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
weight_decay: 0.01

# --- LoRA (recommended to reduce memory with frozen feature network) ---
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

# --- Hardware ---
bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama-1b-opencode

# --- Logging ---
use_tensorboard: true
logging_steps: 1
save_steps: 25
```

examples/ebft/llama-1b-ebft-strided-structured.yaml (Normal file, 65 lines)
@@ -0,0 +1,65 @@
```yaml
# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided              # strided block-parallel generation
  stride: 8                  # tokens between anchor points
  context_length: 8          # context window per block
  generate_max_len: 8        # tokens to generate per block
  n_samples_per_prompt: 4    # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03              # small CE weight for structured data
  advantage_estimator: rloo
  min_completion_prefix: 8   # skip anchors too close to prompt boundary

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false  # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true    # fused flex_attention kernel compiles itself; don't set torch_compile: true
                        # (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # required for flex_attention (non-reentrant causes CheckpointError)

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured

wandb_project: ebft
logging_steps: 1
save_steps: 100
```

examples/ebft/llama-1b-ebft-strided.yaml (Normal file, 60 lines)
@@ -0,0 +1,60 @@
```yaml
# EBFT Strided Mode: For unstructured text data (raw code, prose)
# Uses strided block-parallel generation — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided            # strided block-parallel generation
  stride: 8                # tokens between anchor points
  context_length: 8        # context window per block
  generate_max_len: 8      # tokens to generate per block
  n_samples_per_prompt: 4  # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:100]

sequence_len: 256
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 5

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false  # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-validation

wandb_project: ebft
logging_steps: 1
save_steps: 100
```

examples/ebft/llama-3b-ebft-strided-fft.yaml (Normal file, 69 lines)
@@ -0,0 +1,69 @@
```yaml
# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.2-3B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0  # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: true
ddp: false
device_map:
  "": 0

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided

wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50
```

examples/ebft/llama-8b-ebft-strided-fft.yaml (Normal file, 58 lines)
@@ -0,0 +1,58 @@
```yaml
# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# Feature network is CPU-offloaded to fit in single 32GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.1-8B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

bf16: auto
flash_attention: false  # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided

wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50
```
86
examples/ebft/qwen35-4b-ebft-structured-async.yaml
Normal file
86
examples/ebft/qwen35-4b-ebft-structured-async.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
|
||||
#
|
||||
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
|
||||
# full attention every 4th layer. This tests EBFT compatibility.
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1. Start vLLM on GPU 0:
|
||||
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
|
||||
#
|
||||
# 2. Run training on GPU 1:
|
||||
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
# axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml
|
||||
|
||||
base_model: Qwen/Qwen3.5-4B
|
||||
|
||||
rl: ebft
|
||||
|
||||
ebft:
|
||||
feature_layers: [0.25, 0.5, 0.75]
|
||||
embed_method: last_token
|
||||
use_whitening: false
|
||||
alignment_coef: 1.0
|
||||
diversity_coef: 1.0
|
||||
ce_coef: 0.0
|
||||
|
||||
trl:
|
||||
num_generations: 4
|
||||
max_completion_length: 256
|
||||
temperature: 0.7
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
scale_rewards: true
|
||||
loss_type: grpo
|
||||
epsilon: 0.2
|
||||
generation_kwargs:
|
||||
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false
|
||||
async_prefetch: true
|
||||
vllm_server_timeout: 300
|
||||
|
||||
vllm:
|
||||
gpu_memory_utilization: 0.5
|
||||
max_model_len: 2048
|
||||
serve_module: axolotl.scripts.vllm_serve_lora
|
||||
enforce_eager: true
|
||||
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_opencode.transform
|
||||
split: train[:500]
|
||||
|
||||
sequence_len: 1024
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
|
||||
learning_rate: 5.0e-6
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
warmup_steps: 3
|
||||
weight_decay: 0.01
|
||||
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
|
||||
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/ebft-qwen35-4b-structured-async
|
||||
|
||||
wandb_project: ebft
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
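The `lora_target_modules` regex above can be sanity-checked in isolation. PEFT matches a string `target_modules` as a regex against each module's fully qualified name; the module names below (`model.layers.N...`) are illustrative placeholders in the Qwen naming style, not dumped from the actual model:

```python
import re

# Pattern from the config above (un-escaped YAML form)
pattern = (
    r".*\.layers\.(3|7|11|15|19|23|27|31)\.self_attn\.(q|k|v|o)_proj"
    r"|.*\.mlp\.(gate|up|down)_proj"
)

# Full-attention q/k/v/o projections match only on every 4th layer
assert re.fullmatch(pattern, "model.layers.3.self_attn.q_proj")
assert not re.fullmatch(pattern, "model.layers.2.self_attn.q_proj")
# MLP projections match on every layer
assert re.fullmatch(pattern, "model.layers.2.mlp.gate_proj")
# Linear-attention modules are deliberately excluded (they break vLLM LoRA)
assert not re.fullmatch(pattern, "model.layers.2.linear_attn.in_proj_qkv")
```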
examples/ebft/qwen35-4b-ebft-structured.yaml (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
|
||||
#
|
||||
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
|
||||
# full attention every 4th layer. This tests EBFT compatibility.
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1. Start vLLM on GPU 0:
|
||||
# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
|
||||
# --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
|
||||
#
|
||||
# 2. Run training on GPU 1:
|
||||
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
# axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml
|
||||
|
||||
base_model: Qwen/Qwen3.5-4B
|
||||
|
||||
rl: ebft
|
||||
|
||||
ebft:
|
||||
feature_layers: [0.25, 0.5, 0.75]
|
||||
embed_method: last_token
|
||||
use_whitening: false
|
||||
alignment_coef: 1.0
|
||||
diversity_coef: 1.0
|
||||
ce_coef: 0.0
|
||||
|
||||
trl:
|
||||
num_generations: 4
|
||||
max_completion_length: 256
|
||||
temperature: 0.7
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
scale_rewards: true
|
||||
loss_type: grpo
|
||||
epsilon: 0.2
|
||||
generation_kwargs:
|
||||
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false # disable Qwen3.5 thinking mode for shorter completions
|
||||
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_opencode.transform
|
||||
split: train[:500]
|
||||
|
||||
sequence_len: 1024
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
|
||||
learning_rate: 5.0e-6
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
warmup_steps: 3
|
||||
weight_decay: 0.01
|
||||
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/ebft-qwen35-4b-structured
|
||||
|
||||
wandb_project: ebft
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
examples/ebft/qwen35-9b-ebft-structured.yaml (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1. Start vLLM on GPU 0:
|
||||
# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
|
||||
#
|
||||
# 2. Run training on GPU 1:
|
||||
# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
# axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml
|
||||
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
|
||||
rl: ebft
|
||||
|
||||
ebft:
|
||||
feature_layers: [0.25, 0.5, 0.75]
|
||||
embed_method: last_token
|
||||
use_whitening: false
|
||||
alignment_coef: 1.0
|
||||
diversity_coef: 1.0
|
||||
ce_coef: 0.0
|
||||
|
||||
trl:
|
||||
num_generations: 4
|
||||
max_completion_length: 256
|
||||
temperature: 0.7
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
scale_rewards: true
|
||||
loss_type: grpo
|
||||
epsilon: 0.2
|
||||
generation_kwargs:
|
||||
stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|>
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false
|
||||
vllm_server_timeout: 300
|
||||
|
||||
vllm:
|
||||
gpu_memory_utilization: 0.7
|
||||
max_model_len: 2048
|
||||
serve_module: axolotl.scripts.vllm_serve_lora
|
||||
enforce_eager: true
|
||||
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_opencode.transform
|
||||
split: train[:500]
|
||||
|
||||
sequence_len: 1024
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
|
||||
learning_rate: 3.0e-6
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
warmup_steps: 3
|
||||
weight_decay: 0.01
|
||||
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
|
||||
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/ebft-qwen35-9b-structured
|
||||
|
||||
wandb_project: ebft
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
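The `alignment_coef` / `diversity_coef` knobs in these configs weight the two reward terms of EBFT's feature matching: cosine similarity of each sampled completion's features to the ground-truth features, minus the mean similarity to the other samples in the group. A minimal pure-Python sketch of that combination (illustrative only, not axolotl's implementation):

```python
import math

def cos(a, b):
    # Cosine similarity between two feature vectors given as plain lists
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def ebft_reward(gen_feats, gt_feat, alignment_coef=1.0, diversity_coef=1.0):
    """Per-sample reward: alignment to ground truth minus mean similarity
    to the other samples in the same generation group."""
    n = len(gen_feats)
    rewards = []
    for i, g in enumerate(gen_feats):
        align = cos(g, gt_feat)
        div = sum(cos(g, h) for j, h in enumerate(gen_feats) if j != i) / (n - 1)
        rewards.append(alignment_coef * align - diversity_coef * div)
    return rewards

# A sample identical to the ground truth but orthogonal to its sibling
# scores 1.0; the orthogonal sibling scores 0.0.
rewards = ebft_reward([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0])
assert rewards == [1.0, 0.0]
```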
@@ -38,18 +38,14 @@ def do_vllm_serve(
     cfg = load_cfg(config)
     model = cfg.base_model

-    # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
+    # Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
+    # We default to axolotl's serve module instead of TRL's because TRL's sends
+    # truncate_prompt_tokens which is unsupported in vLLM 0.17+.
     serve_module = cli_args.get("serve_module") or getattr(
         cfg.vllm, "serve_module", None
     )
-    if (
-        serve_module is None
-        and getattr(cfg, "trl", None)
-        and getattr(cfg.trl, "vllm_lora_sync", False)
-    ):
-        serve_module = "axolotl.scripts.vllm_serve_lora"
     if serve_module is None:
-        serve_module = "trl.scripts.vllm_serve"
+        serve_module = "axolotl.scripts.vllm_serve_lora"
     vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
     tensor_parallel_size = 1
     data_parallel_size = 1
@@ -79,6 +75,12 @@ def do_vllm_serve(
         cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
     )

+    cli_enforce_eager = cli_args.get("enforce_eager")
+    cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
+    raw_enforce_eager = (
+        cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
+    )
+    enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
     base_kwargs = dict(
         model=model,
         tensor_parallel_size=tensor_parallel_size,
@@ -89,6 +91,7 @@ def do_vllm_serve(
         dtype=dtype,
         max_model_len=max_model_len,
         enable_prefix_caching=enable_prefix_caching,
+        enforce_eager=enforce_eager,
     )

     # Use LoRAScriptArguments when serving with native LoRA support
@@ -98,6 +101,10 @@ def do_vllm_serve(
         lora_kwargs = {}
         if hasattr(cfg, "lora_r") and cfg.lora_r:
             lora_kwargs["max_lora_rank"] = cfg.lora_r
+        # Disable native LoRA in vLLM if not using vllm_lora_sync
+        # (merged weight sync via batch_update doesn't need vLLM LoRA mode)
+        if not getattr(cfg.trl, "vllm_lora_sync", False):
+            lora_kwargs["enable_lora"] = False
         vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
     else:
         vllm_script_args = AxolotlScriptArguments(
@@ -118,7 +118,7 @@ def load_preference_datasets(
     train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)

     total_num_steps: int | None = None
-    if cfg.rl is not RLType.GRPO:
+    if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
         total_num_steps = int(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
@@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_cls = AxolotlKTOTrainer
         elif self.cfg.rl is RLType.SIMPO:
             trainer_cls = AxolotlCPOTrainer
+        elif self.cfg.rl is RLType.EBFT:
+            from axolotl.core.trainers.ebft import EBFTStrategy
+
+            trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
+            trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

@@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig
             training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
+
+        elif self.cfg.rl is RLType.EBFT:
+            from axolotl.core.trainers.ebft import EBFTStrategy
+
+            training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
+            training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
+            blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

@@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if (
             self.cfg.adapter
             and self.peft_config
-            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
+            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
         ):
             trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
@@ -4,6 +4,8 @@

 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
+from .ebft.strided import AxolotlStridedEBFTTrainer
+from .ebft.trainer import AxolotlEBFTTrainer
 from .mamba import AxolotlMambaTrainer
 from .trl import (
     AxolotlCPOTrainer,
src/axolotl/core/trainers/ebft/__init__.py (new file, 213 lines)
@@ -0,0 +1,213 @@
|
||||
"""EBFT (Energy-Based Fine-Tuning) Strategy for training
|
||||
|
||||
Two modes:
|
||||
- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
|
||||
- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from axolotl.core.trainers.ebft.args import (
|
||||
AxolotlAsyncEBFTConfig,
|
||||
AxolotlEBFTConfig,
|
||||
AxolotlStridedEBFTConfig,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def _get_ebft_mode(cfg: DictDefault) -> str:
|
||||
"""Determine EBFT mode from config."""
|
||||
if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
|
||||
return cfg.ebft.mode
|
||||
return "structured"
|
||||
|
||||
|
||||
class EBFTStrategy:
|
||||
"""Strategy for EBFT training — dispatches between structured and strided modes."""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls, cfg: DictDefault | None = None):
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer
|
||||
|
||||
return AxolotlStridedEBFTTrainer
|
||||
|
||||
# Structured mode: async or sync
|
||||
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
|
||||
if use_async:
|
||||
from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
|
||||
|
||||
return AxolotlAsyncEBFTTrainer
|
||||
from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer
|
||||
|
||||
return AxolotlEBFTTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls, cfg: DictDefault | None = None):
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
return AxolotlStridedEBFTConfig
|
||||
|
||||
# Structured mode: async or sync config
|
||||
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
|
||||
if use_async:
|
||||
return AxolotlAsyncEBFTConfig
|
||||
return AxolotlEBFTConfig
|
||||
|
||||
@classmethod
|
||||
def is_strided(cls, cfg: DictDefault) -> bool:
|
||||
return _get_ebft_mode(cfg) == "strided"
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
"""Map axolotl YAML config fields to training args kwargs."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
mode = _get_ebft_mode(cfg)
|
||||
|
||||
# Common EBFT fields
|
||||
ebft = cfg.ebft
|
||||
if ebft:
|
||||
if ebft.feature_layers is not None:
|
||||
kwargs["ebft_feature_layers"] = ebft.feature_layers
|
||||
if ebft.embed_method is not None:
|
||||
kwargs["ebft_embed_method"] = ebft.embed_method
|
||||
if ebft.use_whitening is not None:
|
||||
kwargs["ebft_use_whitening"] = ebft.use_whitening
|
||||
if ebft.alignment_coef is not None:
|
||||
kwargs["ebft_alignment_coef"] = ebft.alignment_coef
|
||||
if ebft.diversity_coef is not None:
|
||||
kwargs["ebft_diversity_coef"] = ebft.diversity_coef
|
||||
if ebft.ce_coef is not None:
|
||||
kwargs["ebft_ce_coef"] = ebft.ce_coef
|
||||
if getattr(ebft, "adaptive_max_tokens", None) is not None:
|
||||
kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
|
||||
if getattr(ebft, "gt_length_multiplier", None) is not None:
|
||||
kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier
|
||||
|
||||
if mode == "strided":
|
||||
# Strided-specific fields
|
||||
if ebft:
|
||||
if ebft.stride is not None:
|
||||
kwargs["ebft_stride"] = ebft.stride
|
||||
if ebft.context_length is not None:
|
||||
kwargs["ebft_context_length"] = ebft.context_length
|
||||
if ebft.generate_max_len is not None:
|
||||
kwargs["ebft_generate_max_len"] = ebft.generate_max_len
|
||||
if ebft.n_samples_per_prompt is not None:
|
||||
kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
|
||||
if ebft.temperature is not None:
|
||||
kwargs["ebft_temperature"] = ebft.temperature
|
||||
if ebft.top_p is not None:
|
||||
kwargs["ebft_top_p"] = ebft.top_p
|
||||
if ebft.rl_coef is not None:
|
||||
kwargs["ebft_rl_coef"] = ebft.rl_coef
|
||||
if ebft.advantage_estimator is not None:
|
||||
kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
|
||||
if ebft.min_completion_prefix is not None:
|
||||
kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
|
||||
else:
|
||||
# Structured mode: map TRL config fields
|
||||
trl = cfg.trl
|
||||
if trl:
|
||||
if trl.use_vllm:
|
||||
kwargs["use_vllm"] = trl.use_vllm
|
||||
if trl.vllm_mode:
|
||||
kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
|
||||
vllm_cfg = cfg.vllm
|
||||
if vllm_cfg:
|
||||
kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
kwargs["vllm_tensor_parallel_size"] = (
|
||||
vllm_cfg.tensor_parallel_size
|
||||
)
|
||||
kwargs["vllm_server_host"] = trl.vllm_server_host or (
|
||||
trl.vllm.host if trl.vllm else None
|
||||
)
|
||||
kwargs["vllm_server_port"] = trl.vllm_server_port or (
|
||||
trl.vllm.port if trl.vllm else None
|
||||
)
|
||||
if trl.vllm_server_timeout:
|
||||
kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
|
||||
if trl.num_generations:
|
||||
kwargs["num_generations"] = trl.num_generations
|
||||
if trl.max_completion_length is not None:
|
||||
kwargs["max_completion_length"] = trl.max_completion_length
|
||||
if trl.temperature is not None:
|
||||
kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
kwargs["min_p"] = trl.min_p
|
||||
if trl.num_iterations is not None:
|
||||
kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
kwargs["epsilon"] = trl.epsilon
|
||||
if trl.epsilon_high is not None:
|
||||
kwargs["epsilon_high"] = trl.epsilon_high
|
||||
if trl.scale_rewards is not None:
|
||||
kwargs["scale_rewards"] = trl.scale_rewards
|
||||
if trl.loss_type is not None:
|
||||
kwargs["loss_type"] = trl.loss_type
|
||||
if trl.mask_truncated_completions is not None:
|
||||
kwargs["mask_truncated_completions"] = (
|
||||
trl.mask_truncated_completions
|
||||
)
|
||||
if trl.log_completions is not None:
|
||||
kwargs["log_completions"] = trl.log_completions
|
||||
if trl.num_completions_to_print is not None:
|
||||
kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
if trl.sync_ref_model:
|
||||
kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||
if trl.repetition_penalty is not None:
|
||||
kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
if trl.generation_kwargs is not None:
|
||||
kwargs["generation_kwargs"] = trl.generation_kwargs
|
||||
if trl.chat_template_kwargs is not None:
|
||||
kwargs["chat_template_kwargs"] = trl.chat_template_kwargs
|
||||
|
||||
# Async prefetch fields (only pass when enabled — sync config doesn't have these)
|
||||
if getattr(trl, "async_prefetch", False):
|
||||
kwargs["async_prefetch"] = trl.async_prefetch
|
||||
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||
kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||
if getattr(trl, "vllm_lora_sync", False):
|
||||
kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
return [
|
||||
"dataset_num_proc",
|
||||
"max_length",
|
||||
"max_prompt_length",
|
||||
"include_tokens_per_second",
|
||||
"beta",
|
||||
]
|
||||
return [
|
||||
"dataset_num_proc",
|
||||
"max_length",
|
||||
"include_tokens_per_second",
|
||||
"max_prompt_length",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs):
|
||||
return None
|
||||
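The mode dispatch in `_get_ebft_mode` above defaults to `"structured"` whenever `ebft.mode` is absent or falsy. A minimal sketch of that logic using plain dicts in place of axolotl's `DictDefault` (illustrative only):

```python
def get_ebft_mode(cfg: dict) -> str:
    # Mirror of _get_ebft_mode: fall back to "structured" unless
    # cfg["ebft"]["mode"] is present and truthy.
    ebft = cfg.get("ebft") or {}
    return ebft.get("mode") or "structured"

assert get_ebft_mode({}) == "structured"
assert get_ebft_mode({"ebft": {}}) == "structured"
assert get_ebft_mode({"ebft": {"mode": "strided"}}) == "strided"
```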
src/axolotl/core/trainers/ebft/args.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
EBFT-specific training arguments.
|
||||
|
||||
Two config classes:
|
||||
- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
|
||||
- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
# -- Shared EBFT fields as a mixin --
|
||||
@dataclass
|
||||
class EBFTFieldsMixin:
|
||||
"""Common fields shared between structured and strided EBFT configs."""
|
||||
|
||||
ebft_feature_layers: list[float] = field(
|
||||
default_factory=lambda: [0.25, 0.5, 0.75],
|
||||
metadata={"help": "Fractional layer depths for feature extraction"},
|
||||
)
|
||||
ebft_embed_method: str = field(
|
||||
default="last_token",
|
||||
metadata={"help": "Pooling method: 'last_token', 'mean_pooling', or 'concat'"},
|
||||
)
|
||||
ebft_use_whitening: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Apply SVD whitening to feature embeddings"},
|
||||
)
|
||||
ebft_alignment_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
|
||||
)
|
||||
ebft_diversity_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Coefficient for diversity penalty"},
|
||||
)
|
||||
ebft_ce_coef: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
|
||||
)
|
||||
ebft_adaptive_max_tokens: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
|
||||
)
|
||||
ebft_gt_length_multiplier: float = field(
|
||||
default=1.5,
|
||||
metadata={
|
||||
"help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Structured mode: extends GRPOTrainer for QA data with vLLM --
|
||||
@dataclass
|
||||
class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
|
||||
"""EBFT config for structured QA data — extends GRPOConfig."""
|
||||
|
||||
vllm_lora_sync: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Async structured mode: extends FastAsyncGRPOConfig --
|
||||
@dataclass
|
||||
class AxolotlAsyncEBFTConfig(
|
||||
EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
|
||||
):
|
||||
"""EBFT config for async structured QA data — extends FastAsyncGRPOConfig.
|
||||
|
||||
Includes all async fields: async_prefetch, vllm_lora_sync,
|
||||
skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
|
||||
"""
|
||||
|
||||
vllm_lora_sync: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Strided mode: extends TrainingArguments for unstructured text --
|
||||
@dataclass
|
||||
class AxolotlStridedEBFTConfig(
|
||||
EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
|
||||
):
|
||||
"""EBFT config for unstructured text with strided block-parallel generation."""
|
||||
|
||||
ebft_stride: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Stride between anchor points (in tokens)"},
|
||||
)
|
||||
ebft_context_length: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Context window size for each block"},
|
||||
)
|
||||
ebft_generate_max_len: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of tokens to generate per block"},
|
||||
)
|
||||
ebft_n_samples_per_prompt: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of independent rollouts per document"},
|
||||
)
|
||||
ebft_temperature: float = field(
|
||||
default=0.6,
|
||||
metadata={"help": "Sampling temperature for strided generation"},
|
||||
)
|
||||
ebft_top_p: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Top-p nucleus sampling threshold"},
|
||||
)
|
||||
ebft_rl_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "RL policy gradient loss coefficient"},
|
||||
)
|
||||
ebft_advantage_estimator: str = field(
|
||||
default="rloo",
|
||||
metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
|
||||
)
|
||||
ebft_min_completion_prefix: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Minimum tokens into completion before placing anchors"},
|
||||
)
|
||||
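The `ebft_adaptive_max_tokens` / `ebft_gt_length_multiplier` pair above caps generation length relative to the ground-truth completion. A hypothetical helper sketching that computation (the function name and clamping are illustrative, not axolotl's API; the default multiplier mirrors the field default of 1.5):

```python
def adaptive_max_tokens(gt_token_count: int, gt_length_multiplier: float = 1.5) -> int:
    """Per-batch generation budget: ground-truth token count scaled by the
    multiplier, clamped to at least one token."""
    return max(1, int(gt_token_count * gt_length_multiplier))

assert adaptive_max_tokens(100) == 150
assert adaptive_max_tokens(100, gt_length_multiplier=2.0) == 200
assert adaptive_max_tokens(0) == 1  # clamp guards degenerate ground truths
```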
src/axolotl/core/trainers/ebft/kernels.py (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Fused Triton kernels for strided EBFT.
|
||||
|
||||
These kernels eliminate intermediate tensor materializations that dominate
|
||||
the elementwise/fill category (~40% of CUDA time in profiling).
|
||||
|
||||
Kernels:
|
||||
1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
|
||||
2. fused_masked_reinforce_loss: -logp * advantage * mask, reduced to scalar
|
||||
3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Fused log_softmax + gather (selective log softmax)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
|
||||
# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
|
||||
# This avoids materializing the full (B, S, V) log_softmax output.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_log_softmax_gather_kernel(
|
||||
logits_ptr, # (B*S, V) row-major
|
||||
labels_ptr, # (B*S,) int64
|
||||
output_ptr, # (B*S,) float32
|
||||
V: tl.constexpr, # vocab size
|
||||
BLOCK_V: tl.constexpr, # tile width over vocab
|
||||
):
|
||||
"""Compute log_softmax(logits)[label] for each row without materializing full output."""
|
||||
row = tl.program_id(0)
|
||||
|
||||
logits_row_ptr = logits_ptr + row * V
|
||||
label = tl.load(labels_ptr + row)
|
||||
|
||||
# Pass 1: find max for numerical stability
|
||||
max_val = -float("inf")
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
v_offsets = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = v_offsets < V
|
||||
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
|
||||
max_val = tl.maximum(max_val, tl.max(vals, axis=0))
|
||||
|
||||
# Pass 2: compute sum(exp(x - max))
|
||||
sum_exp = 0.0
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
v_offsets = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = v_offsets < V
|
||||
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
|
||||
sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)
|
||||
|
||||
log_sum_exp = tl.log(sum_exp)
|
||||
|
||||
# Gather: log_softmax[label] = logits[label] - max - log_sum_exp
|
||||
target_logit = tl.load(logits_row_ptr + label)
|
||||
result = target_logit - max_val - log_sum_exp
|
||||
|
||||
tl.store(output_ptr + row, result)
|
||||
|
||||
|
||||
def fused_log_softmax_gather(
|
||||
logits: torch.Tensor, labels: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
|
||||
|
||||
Args:
|
||||
logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
|
||||
labels: (B, S) or (B*S,) int64 tensor of target indices
|
||||
|
||||
Returns:
|
||||
(B, S) or (B*S,) float32 tensor of selected log probabilities
|
||||
"""
|
||||
orig_shape = logits.shape[:-1]
|
||||
V = logits.shape[-1]
|
||||
logits_2d = logits.reshape(-1, V).contiguous()
|
||||
labels_1d = labels.reshape(-1).contiguous()
|
||||
n_rows = logits_2d.shape[0]
|
||||
|
||||
output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
|
||||
|
||||
# Choose BLOCK_V: must be power of 2, large enough for good occupancy
|
||||
BLOCK_V = min(triton.next_power_of_2(V), 65536)
|
||||
|
||||
_fused_log_softmax_gather_kernel[(n_rows,)](
|
||||
logits_2d,
|
||||
labels_1d,
|
||||
output,
|
||||
V=V,
|
||||
BLOCK_V=BLOCK_V,
|
||||
)
|
||||
|
||||
return output.view(orig_shape)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Fused masked REINFORCE loss reduction
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: (-logp * adv * mask).sum() / mask.sum()
|
||||
# We do the masked multiply-accumulate in one kernel, returning (sum, count).
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_reinforce_loss_kernel(
|
||||
logps_ptr, # (N,) float32 per-token log probs
|
||||
advs_ptr, # (N,) float32 advantages
|
||||
mask_ptr, # (N,) bool action mask
|
||||
partial_sum_ptr, # (n_blocks,) partial sums
|
||||
partial_cnt_ptr, # (n_blocks,) partial counts
|
||||
N: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
block_id = tl.program_id(0)
|
||||
offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
valid = offsets < N
|
||||
|
||||
logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
|
||||
advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
|
||||
m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)
|
||||
|
||||
# -logp * advantage * mask
|
||||
loss = -logps * advs * m
|
||||
block_sum = tl.sum(loss, axis=0)
|
||||
block_cnt = tl.sum(m, axis=0)
|
||||
|
||||
tl.store(partial_sum_ptr + block_id, block_sum)
|
||||
tl.store(partial_cnt_ptr + block_id, block_cnt)
|
||||
|
||||
|
||||
def fused_reinforce_loss(
|
||||
per_token_logps: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
action_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
|
||||
|
||||
All inputs should be flat or will be flattened. Returns scalar loss tensor.
|
||||
"""
|
||||
logps_flat = per_token_logps.reshape(-1).contiguous()
|
||||
advs_flat = advantages.reshape(-1).contiguous()
|
||||
mask_flat = action_mask.reshape(-1).contiguous()
|
||||
N = logps_flat.shape[0]
|
||||
|
||||
BLOCK_N = 1024
|
||||
n_blocks = triton.cdiv(N, BLOCK_N)
|
||||
|
||||
partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
|
||||
partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
|
||||
|
||||
_fused_reinforce_loss_kernel[(n_blocks,)](
|
||||
logps_flat,
|
||||
advs_flat,
|
||||
mask_flat,
|
||||
partial_sum,
|
||||
partial_cnt,
|
||||
N=N,
|
||||
BLOCK_N=BLOCK_N,
|
||||
)
|
||||
|
||||
total_sum = partial_sum.sum()
|
||||
total_cnt = partial_cnt.sum().clamp(min=1)
|
||||
return total_sum / total_cnt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# 3. Fused cosine similarity (batched, for EBFT rewards)
# ---------------------------------------------------------------------------
# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
# we fuse the dot product, norm computation, and division into one kernel.


@triton.jit
def _fused_cosine_sim_kernel(
    a_ptr,  # (N, D) first set of vectors
    b_ptr,  # (N, D) second set of vectors
    out_ptr,  # (N,) cosine similarities
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    a_row_ptr = a_ptr + row * D
    b_row_ptr = b_ptr + row * D

    dot = 0.0
    norm_a = 0.0
    norm_b = 0.0

    for d_start in range(0, D, BLOCK_D):
        d_offsets = d_start + tl.arange(0, BLOCK_D)
        mask = d_offsets < D
        a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
        b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)

        dot += tl.sum(a_vals * b_vals, axis=0)
        norm_a += tl.sum(a_vals * a_vals, axis=0)
        norm_b += tl.sum(b_vals * b_vals, axis=0)

    denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
    denom = tl.where(denom > 1e-8, denom, 1e-8)
    result = dot / denom

    tl.store(out_ptr + row, result)

def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute cosine similarity along the last dimension.

    Args:
        a, b: (..., D) tensors of the same shape

    Returns:
        (...,) tensor of cosine similarities
    """
    orig_shape = a.shape[:-1]
    D = a.shape[-1]
    a_2d = a.reshape(-1, D).contiguous()
    b_2d = b.reshape(-1, D).contiguous()
    N = a_2d.shape[0]

    output = torch.empty(N, device=a.device, dtype=torch.float32)

    BLOCK_D = min(triton.next_power_of_2(D), 4096)

    _fused_cosine_sim_kernel[(N,)](
        a_2d,
        b_2d,
        output,
        D=D,
        BLOCK_D=BLOCK_D,
    )

    return output.view(orig_shape)


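A pure-Python reference for one row of the fused kernel, including the `1e-8` floor on the denominator that guards against zero vectors:

```python
import math

def cosine_sim(a, b):
    # Dot product over the feature dimension, divided by the product of
    # the two L2 norms; the denominator is floored at 1e-8 as in the kernel.
    dot = sum(x * y for x, y in zip(a, b))
    denom = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(denom, 1e-8)
```

With the floor, a zero vector yields a similarity of 0 rather than a NaN.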
# ---------------------------------------------------------------------------
# 4. Fused pairwise diversity penalty
# ---------------------------------------------------------------------------
# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1),
# we compute the pairwise dot products and exclusion in one kernel.


@triton.jit
def _fused_diversity_kernel(
    emb_ptr,  # (B, N, D) embeddings, row-major
    out_ptr,  # (B, N) diversity penalties
    N: tl.constexpr,  # n_samples
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each (b, i), compute mean dot product to all j != i."""
    b = tl.program_id(0)
    i = tl.program_id(1)

    # Pointer to emb[b, i, :]
    emb_bi_ptr = emb_ptr + (b * N + i) * D

    total_sim = 0.0
    for j in range(N):
        emb_bj_ptr = emb_ptr + (b * N + j) * D

        dot = 0.0
        for d_start in range(0, D, BLOCK_D):
            d_offsets = d_start + tl.arange(0, BLOCK_D)
            d_mask = d_offsets < D
            a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
                tl.float32
            )
            b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
                tl.float32
            )
            dot += tl.sum(a_vals * b_vals, axis=0)

        # Exclude self-similarity (j == i)
        is_other = j != i
        total_sim += dot * is_other

    result = total_sim / (N - 1)
    tl.store(out_ptr + b * N + i, result)

def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    """Compute mean pairwise dot product (excluding self) per sample.

    Args:
        embeddings: (B, N, D) tensor where N is n_samples

    Returns:
        (B, N) tensor of diversity penalties
    """
    B, N, D = embeddings.shape
    embeddings = embeddings.contiguous()
    output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
    if N <= 1:
        return output  # diversity is 0 when there's only one sample

    BLOCK_D = min(triton.next_power_of_2(D), 4096)

    _fused_diversity_kernel[(B, N)](
        embeddings,
        output,
        N=N,
        D=D,
        BLOCK_D=BLOCK_D,
    )

    return output
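The same computation for a single prompt group, written as a plain-Python reference: for each sample `i`, take the mean dot product against the other `N-1` samples.

```python
def diversity_penalty(group):
    # group: list of N embedding vectors from the same prompt.
    n = len(group)
    if n <= 1:
        return [0.0] * n  # matches the N <= 1 early return above

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Mean similarity to every other sample (self excluded).
    return [
        sum(dot(group[i], group[j]) for j in range(n) if j != i) / (n - 1)
        for i in range(n)
    ]
```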
264
src/axolotl/core/trainers/ebft/rewards.py
Normal file
@@ -0,0 +1,264 @@
"""
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).

Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def extract_hidden_states(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_indices: list[int],
    batch_size: int | None = None,
) -> torch.Tensor:
    """
    Forward pass through model, extracting and concatenating hidden states
    at specified layer indices.

    Args:
        model: The frozen feature network
        input_ids: (B, S) token ids
        attention_mask: (B, S) attention mask
        layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model)
        batch_size: If set, process in chunks to reduce peak memory

    Returns:
        Concatenated hidden states: (B, S, num_layers * H)
    """
    if batch_size is None:
        batch_size = input_ids.shape[0]

    # Use the inner transformer body (skips lm_head) when available.
    # This avoids the expensive hidden_dim × vocab_size matmul whose
    # output (logits) is never used — only hidden_states are needed.
    body = getattr(model, "model", None)
    if body is not None and hasattr(body, "forward"):
        forward_model = body
    else:
        forward_model = model

    all_features = []
    for i in range(0, input_ids.shape[0], batch_size):
        chunk_ids = input_ids[i : i + batch_size]
        chunk_mask = attention_mask[i : i + batch_size]

        outputs = forward_model(
            chunk_ids,
            attention_mask=chunk_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H);
        # index 0 is the embedding layer output
        hidden_states = outputs.hidden_states
        chunk_features = []
        for idx in layer_indices:
            chunk_features.append(hidden_states[idx])

        # Concatenate across feature dimension: (B, S, num_layers * H)
        all_features.append(torch.cat(chunk_features, dim=-1))

    return torch.cat(all_features, dim=0)


def apply_embed_method(
    hidden_states: torch.Tensor,
    method: str,
    attention_mask: torch.Tensor | None = None,
    prompt_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Pool per-token hidden states into per-sequence embeddings.

    Args:
        hidden_states: (B, S, D) concatenated hidden states
        method: One of "last_token", "mean_pooling", "completion_mean", "concat"
        attention_mask: (B, S) mask for mean pooling
        prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)

    Returns:
        Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
        (B, 3*D) for concat
    """
    if method == "last_token":
        if attention_mask is not None:
            # Find last non-padding position per sample
            last_idx = attention_mask.sum(dim=1).long() - 1  # (B,)
            return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
        return hidden_states[:, -1, :]

    if method == "mean_pooling":
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # (B, S, 1)
            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return hidden_states.mean(dim=1)

    if method == "completion_mean":
        # Mean pool over completion tokens only (exclude prompt)
        if prompt_lengths is None:
            raise ValueError("completion_mean requires prompt_lengths")
        B, S, _ = hidden_states.shape
        positions = torch.arange(S, device=hidden_states.device).unsqueeze(0)  # (1, S)
        comp_mask = positions >= prompt_lengths.unsqueeze(1)  # (B, S)
        if attention_mask is not None:
            comp_mask = comp_mask & attention_mask.bool()
        mask = comp_mask.unsqueeze(-1).float()  # (B, S, 1)
        return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    if method == "concat":
        B, S, D = hidden_states.shape
        if attention_mask is not None:
            valid_lens = attention_mask.sum(dim=1).long()  # (B,)
        else:
            valid_lens = torch.full(
                (B,), S, device=hidden_states.device, dtype=torch.long
            )
        # Compute quartile positions relative to valid length per sample
        q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
        q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
        q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
        batch_idx = torch.arange(B, device=hidden_states.device)
        return torch.cat(
            [
                hidden_states[batch_idx, q1],
                hidden_states[batch_idx, q2],
                hidden_states[batch_idx, q3],
            ],
            dim=-1,
        )

    raise ValueError(f"Unknown embed_method: {method}")


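The `"completion_mean"` branch above can be sketched for a single unpadded sample in plain Python: average the hidden states over completion positions only, skipping the first `prompt_len` (prompt) positions.

```python
def completion_mean(hidden, prompt_len):
    # hidden: list of S per-token vectors for one sample; prompt_len: number
    # of prompt tokens. Returns the mean over completion positions only.
    comp = hidden[prompt_len:]
    if not comp:
        # Degenerate case (no completion tokens); the tensor version avoids
        # division by zero via clamp(min=1), which yields zeros here.
        return [0.0] * len(hidden[0])
    return [sum(col) / len(comp) for col in zip(*comp)]
```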
@torch.no_grad()
def get_alignment_rewards(
    gen_embedding: torch.Tensor,
    gt_embedding: torch.Tensor,
) -> torch.Tensor:
    """
    Compute alignment reward as cosine similarity between generated
    and ground-truth feature embeddings.

    Args:
        gen_embedding: (B, D) generated sequence embeddings
        gt_embedding: (B, D) ground-truth sequence embeddings.
            If num_generations > 1, gt_embedding should be repeated
            to match gen_embedding's batch dim.

    Returns:
        Alignment rewards: (B,) cosine similarities in [-1, 1]
    """
    return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)


@torch.no_grad()
def get_diversity_rewards(
    gen_embedding: torch.Tensor,
    num_generations: int,
) -> torch.Tensor:
    """
    Compute diversity penalty as mean pairwise dot-product similarity
    between samples from the same prompt (excluding self-similarity).

    Args:
        gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
        num_generations: Number of generations per prompt

    Returns:
        Diversity penalties: (B,) mean similarity to other samples from same prompt
    """
    if num_generations <= 1:
        return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)

    num_prompts = gen_embedding.shape[0] // num_generations

    # Reshape to (num_prompts, num_generations, D)
    reshaped = gen_embedding.view(num_prompts, num_generations, -1)

    # Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
    sims = torch.bmm(reshaped, reshaped.transpose(1, 2))

    # Zero out self-similarity (diagonal)
    eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
    sims = sims.masked_fill(eye.unsqueeze(0), 0.0)

    # Mean similarity to other samples: (num_prompts, num_generations)
    diversity = sims.sum(dim=-1) / (num_generations - 1)

    # Flatten back to (B,)
    return diversity.view(-1)


def whiten_embeddings_batched(
    phi: torch.Tensor,
    phi_gt: torch.Tensor,
    whiten_tol: float = 1e-5,
    normalize: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Whiten generated embeddings using SVD, then apply same transform to ground-truth.

    Whitening decorrelates feature dimensions so no single direction dominates
    the feature-matching loss. Uses pseudo-inverse for rank-deficient cases.

    Note: Singular values scale with sqrt(B), so reward magnitudes are
    batch-size dependent. This is acceptable because B = n_samples_per_prompt,
    which is fixed during training (typically 2-4).

    Args:
        phi: (B, D) generated embeddings (used to estimate covariance)
        phi_gt: (B, D) ground-truth embeddings
        whiten_tol: Tolerance for singular value cutoff
        normalize: If True, L2-normalize after whitening

    Returns:
        Whitened (phi, phi_gt) tuple, each (B, D)
    """
    phi_f = phi.float()
    phi_gt_f = phi_gt.float()

    # Feature-space SVD: operate on phi_f.T (D, B) so U's columns live in
    # feature space; with full_matrices=False, U is (D, min(D, B))
    try:
        U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
    except torch._C._LinAlgError:
        # Fallback: add small noise
        noise = 1e-6 * phi_f.abs().mean()
        try:
            U, S, _ = torch.linalg.svd(
                (phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
                full_matrices=False,
            )
        except torch._C._LinAlgError:
            if normalize:
                return (
                    F.normalize(phi, p=2, dim=-1),
                    F.normalize(phi_gt, p=2, dim=-1),
                )
            return phi, phi_gt

    U, S = U.squeeze(0), S.squeeze(0)  # U: (D, min(D,B)), S: (min(D,B),)

    # Safe inverse of singular values
    s_max = S.max()
    inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))

    # W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
    W = (U * inv_s.unsqueeze(0)) @ U.T  # (D, D)
    phi_w = (phi_f @ W).to(phi.dtype)  # (B, D)
    phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype)  # (B, D)

    if normalize:
        phi_w = F.normalize(phi_w, p=2, dim=-1)
        phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)

    return phi_w, phi_gt_w
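Written out, the transform computed above is (a sketch restating the code, with τ = `whiten_tol`):

```latex
% Thin SVD of the transposed embeddings: \Phi^\top = U \Sigma V^\top,
% with \Phi \in \mathbb{R}^{B \times D}.
W = U \,\Sigma^{+}\, U^\top,
\qquad
\Sigma^{+}_{ii} =
\begin{cases}
1/\sigma_i, & \sigma_i > \tau\,\sigma_{\max},\\
0, & \text{otherwise},
\end{cases}
\qquad
\Phi_w = \Phi W,\quad \Phi_{gt,w} = \Phi_{gt} W.
```

Singular directions below the cutoff are zeroed rather than inverted, which is the pseudo-inverse behavior the docstring refers to for rank-deficient batches.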
1152
src/axolotl/core/trainers/ebft/strided.py
Normal file
File diff suppressed because it is too large Load Diff
531
src/axolotl/core/trainers/ebft/trainer.py
Normal file
@@ -0,0 +1,531 @@
"""
EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.

Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
the standard GRPO reward function interface.

Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""

import contextlib
import copy
from typing import TYPE_CHECKING, Any

import torch
from datasets import Dataset, IterableDataset
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback

from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
from axolotl.core.trainers.ebft.rewards import (
    apply_embed_method,
    extract_hidden_states,
    get_alignment_rewards,
    get_diversity_rewards,
    whiten_embeddings_batched,
)
from axolotl.core.trainers.grpo.trainer import (
    AxolotlAsyncGRPOTrainer,
    AxolotlGRPOTrainer,
)
from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from collections import defaultdict

    from accelerate import Accelerator
    from trl.generation.vllm_generation import VLLMGeneration

LOG = get_logger(__name__)


class EBFTMixin:
    """
    Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.

    Provides:
        - Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
        - _feature_matching_reward() callable for GRPO reward function interface
        - _sequential_rollout() for multi-turn conversations
    """

    # Type stubs for attributes provided by the composed GRPOTrainer base class.
    # These are not defined here but accessed via cooperative multiple inheritance.
    if TYPE_CHECKING:
        accelerator: Accelerator
        model: PreTrainedModel
        args: AxolotlEBFTConfig
        processing_class: PreTrainedTokenizerBase
        num_generations: int
        vllm_generation: VLLMGeneration
        _metrics: defaultdict

    _tag_names = ["trl", "ebft", "axolotl"]

    def __init__(
        self,
        model: str | PreTrainedModel,
        args: AxolotlEBFTConfig | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
        eval_dataset: Dataset
        | IterableDataset
        | dict[str, Dataset | IterableDataset]
        | None = None,
        processing_class: PreTrainedTokenizerBase | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[
            torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
        ] = (None, None),
        peft_config: Any | None = None,
    ):
        # Pass our feature-matching reward function to GRPOTrainer.
        # It will be called with (prompts, completions, **kwargs), where
        # kwargs includes all extra dataset fields like "ground_truth".
        super().__init__(  # type: ignore[call-arg]
            model=model,
            reward_funcs=[self._feature_matching_reward],
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config,
        )
        assert args is not None

        # --- Feature network setup ---
        unwrapped = self.accelerator.unwrap_model(self.model)
        # Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping
        self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
            unwrapped, "disable_adapter"
        )

        if self._share_feature_weights:
            # Share weights: use actor's base model with adapters disabled.
            # Saves a full model copy (~8 GB for 4B model).
            self.feature_network = None
            param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
            LOG.info(
                f"EBFT feature network shares actor weights (PEFT disable_adapter). "
                f"Saving ~{param_gb:.1f} GB"
            )
        else:
            LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
            self.feature_network = copy.deepcopy(unwrapped)
            for param in self.feature_network.parameters():
                param.requires_grad = False
            self.feature_network.eval()

        # Compute layer indices from fractional depths.
        # Handle VLM models where num_hidden_layers is on text_config.
        config = unwrapped.config
        if hasattr(config, "text_config") and hasattr(
            config.text_config, "num_hidden_layers"
        ):
            config = config.text_config
        num_layers = config.num_hidden_layers
        self.feature_layer_indices = [
            int(frac * num_layers) for frac in args.ebft_feature_layers
        ]
        LOG.info(
            f"EBFT feature extraction from layers {self.feature_layer_indices} "
            f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
        )
        if args.ebft_adaptive_max_tokens:
            LOG.info(
                f"EBFT adaptive max_tokens enabled "
                f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
            )

    _adaptive_max_lock = None  # initialized lazily

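The fractional-depth mapping in `__init__` is a one-liner worth seeing in isolation. The fractions `[0.25, 0.5, 0.75]` here are illustrative, not necessarily what `ebft_feature_layers` defaults to:

```python
def layer_indices(fractions, num_layers):
    # Map fractional depths (0..1) to integer layer indices by truncation,
    # as done for self.feature_layer_indices above.
    return [int(frac * num_layers) for frac in fractions]
```

For a 32-layer model, `[0.25, 0.5, 0.75]` maps to layers 8, 16, and 24, matching the example in the `extract_hidden_states` docstring.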
    def _generate_only(self, inputs, rank0_only=False):
        """Override to set per-batch max_tokens based on ground-truth length.

        Uses a lock to prevent race conditions in async mode where concurrent
        BG threads could interleave mutations of max_completion_length.
        """
        import threading

        args = self.args
        if (
            args.ebft_adaptive_max_tokens
            and hasattr(self, "vllm_generation")
            and inputs
        ):
            gt_texts = [
                x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
            ]
            if gt_texts:
                gt_token_counts = [
                    len(self.processing_class.encode(gt, add_special_tokens=False))
                    for gt in gt_texts
                ]
                multiplier = args.ebft_gt_length_multiplier
                max_completion = self.vllm_generation.max_completion_length
                adaptive_max = max(
                    min(int(c * multiplier), max_completion) for c in gt_token_counts
                )
                adaptive_max = max(adaptive_max, 64)

                if self._adaptive_max_lock is None:
                    self._adaptive_max_lock = threading.Lock()
                with self._adaptive_max_lock:
                    original = self.vllm_generation.max_completion_length
                    self.vllm_generation.max_completion_length = adaptive_max
                    try:
                        return super()._generate_only(inputs, rank0_only)
                    finally:
                        self.vllm_generation.max_completion_length = original

        return super()._generate_only(inputs, rank0_only)

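The adaptive budget rule above, extracted as a standalone sketch: scale each ground-truth token count by the multiplier, cap at the configured completion limit, take the batch max, and floor the result at 64 tokens.

```python
def adaptive_max_tokens(gt_token_counts, multiplier, max_completion):
    # Per-GT budget: int(count * multiplier), capped at max_completion;
    # the batch uses the largest per-GT budget, floored at 64 tokens.
    adaptive = max(min(int(c * multiplier), max_completion) for c in gt_token_counts)
    return max(adaptive, 64)
```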
    @torch.no_grad()
    def _feature_matching_reward(
        self,
        prompts: list,
        completions: list,
        ground_truth: list[str] | None = None,
        remaining_turns: list | None = None,
        **kwargs,
    ) -> list[float]:
        """
        Compute feature-matching rewards for generated completions.

        This is called by GRPOTrainer's _generate_and_score_completions()
        as a standard reward function. The `ground_truth` field comes from
        the dataset via reward_kwargs.

        For multi-turn conversations, `remaining_turns` contains the subsequent
        user/assistant turn pairs. When present, we do sequential rollouts:
        generate each assistant turn conditioned on history + previous generations,
        then compute feature-matching rewards on the full generated conversation.

        Args:
            prompts: List of prompt strings/messages
            completions: List of generated completion strings
            ground_truth: List of reference completion strings (from dataset)
            remaining_turns: List of remaining conversation turns after the
                first assistant turn (for multi-turn rollouts)

        Returns:
            List of scalar rewards, one per completion
        """
        if ground_truth is None:
            LOG.warning("No ground_truth field in dataset — using zero rewards")
            return [0.0] * len(prompts)

        device = self.accelerator.device
        args = self.args
        num_gens = self.num_generations

        # --- Multi-turn sequential rollout ---
        # If remaining_turns is provided, generate subsequent assistant turns
        # by calling vLLM for each turn, building up the full conversation.
        if remaining_turns is not None and hasattr(self, "vllm_generation"):
            completions = self._sequential_rollout(
                prompts, completions, remaining_turns, num_gens
            )

        # --- Tokenize generated sequences: prompt + completion ---
        gen_texts = []
        gen_prompt_texts = []
        for p, c in zip(prompts, completions, strict=True):
            if isinstance(p, list):
                prompt_text = self.processing_class.apply_chat_template(
                    p, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = p
            if isinstance(c, list):
                comp_text = c[0].get("content", "") if c else ""
            else:
                comp_text = c
            gen_texts.append(prompt_text + comp_text)
            gen_prompt_texts.append(prompt_text)

        gen_encoded = self.processing_class(
            text=gen_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(self.args, "max_length", None)
            or getattr(self.args, "max_seq_length", None)
            or 2048,
            add_special_tokens=False,
        )
        gen_ids = gen_encoded["input_ids"].to(device)
        gen_mask = gen_encoded["attention_mask"].to(device)

        # Compute prompt lengths for completion_mean pooling
        gen_prompt_lengths = torch.tensor(
            [
                len(self.processing_class.encode(pt, add_special_tokens=False))
                for pt in gen_prompt_texts
            ],
            device=device,
        )

        # --- Tokenize ground-truth sequences: prompt + ground_truth ---
        # For multi-turn (remaining_turns present), render the full GT conversation
        # through the chat template to preserve role markers between turns.
        gt_texts = []
        gt_prompt_texts = []
        for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
            if i % num_gens != 0:
                continue  # Only need one GT per prompt group
            if isinstance(p, list):
                prompt_text = self.processing_class.apply_chat_template(
                    p, tokenize=False, add_generation_prompt=True
                )
                # Multi-turn: build full GT conversation with remaining turns
                if remaining_turns is not None:
                    prompt_idx = i // num_gens
                    turns = (
                        remaining_turns[prompt_idx]
                        if prompt_idx < len(remaining_turns)
                        else []
                    )
                    if turns:
                        gt_conv = list(p) + [{"role": "assistant", "content": gt}]
                        gt_conv.extend(turns)
                        full_gt_text = self.processing_class.apply_chat_template(
                            gt_conv, tokenize=False, add_generation_prompt=False
                        )
                        gt_texts.append(full_gt_text)
                        gt_prompt_texts.append(prompt_text)
                        continue
            else:
                prompt_text = p
            gt_texts.append(prompt_text + gt)
            gt_prompt_texts.append(prompt_text)

        gt_encoded = self.processing_class(
            text=gt_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(self.args, "max_length", None)
            or getattr(self.args, "max_seq_length", None)
            or 2048,
            add_special_tokens=False,
        )
        gt_ids = gt_encoded["input_ids"].to(device)
        gt_mask = gt_encoded["attention_mask"].to(device)

        gt_prompt_lengths = torch.tensor(
            [
                len(self.processing_class.encode(pt, add_special_tokens=False))
                for pt in gt_prompt_texts
            ],
            device=device,
        )

        # --- Extract features from frozen feature network ---
        # INVARIANT: disable_adapter() yields the unmodified base weights because
        # _sync_peft_weights_no_merge and _sync_lora_adapter never call
        # merge_adapter() — they compute merged weights as new tensors or save
        # the adapter to filesystem. Base weights are never modified in-place.
        if self._share_feature_weights:
            unwrapped = self.accelerator.unwrap_model(self.model)
            feature_ctx = unwrapped.disable_adapter()
        else:
            unwrapped = self.feature_network
            feature_ctx = contextlib.nullcontext()

        with feature_ctx:
            was_training = unwrapped.training
            unwrapped.eval()
            gen_hidden = extract_hidden_states(
                unwrapped, gen_ids, gen_mask, self.feature_layer_indices
            )
            gt_hidden = extract_hidden_states(
                unwrapped, gt_ids, gt_mask, self.feature_layer_indices
            )
            if was_training:
                unwrapped.train()

        # --- Pool to sequence-level embeddings ---
        gen_emb = apply_embed_method(
            gen_hidden,
            args.ebft_embed_method,
            gen_mask,
            prompt_lengths=gen_prompt_lengths,
        )
        gt_emb = apply_embed_method(
            gt_hidden,
            args.ebft_embed_method,
            gt_mask,
            prompt_lengths=gt_prompt_lengths,
        )

        # --- Optional whitening ---
        batch_size = gen_emb.shape[0]
        if args.ebft_use_whitening and batch_size > 1:
            num_prompts = batch_size // num_gens
            gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
            whitened_gen_list = []
            whitened_gt_list = []
            for i in range(num_prompts):
                w_gen, w_gt = whiten_embeddings_batched(
                    gen_reshaped[i], gt_emb[i : i + 1]
                )
                whitened_gen_list.append(w_gen)
                whitened_gt_list.append(w_gt)
            gen_emb = torch.cat(whitened_gen_list, dim=0)
            gt_emb = torch.cat(whitened_gt_list, dim=0)
        else:
            gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
            gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)

        # Repeat gt_emb: each GT repeated num_generations times
        gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)

        # --- Compute rewards ---
        alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
        diversity = get_diversity_rewards(gen_emb, num_gens)

        # Scale by 2 per paper equation (7):
        # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
        alignment = alignment * 2
        diversity = diversity * 2

        rewards = (
            args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
        )

        # Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
        gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
        mean_gen = gen_reshaped.mean(dim=1)  # (num_prompts, D)
        cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()

        # Log feature-matching metrics to console and wandb
        _align = alignment.mean().item()
        _divers = diversity.mean().item()
        _reward = rewards.mean().item()
        _cfm = cfm_loss.item()

        LOG.info(
            f"ebft reward | "
            f"align {_align:+.3f} ^ | "
            f"divers {_divers:+.3f} v | "
            f"cfm {_cfm:.3f} v | "
            f"reward {_reward:+.3f} ^"
        )

        # Log to wandb via trainer's _metrics (picked up by GRPO's logging)
        mode = "train" if self.model.training else "eval"
        if hasattr(self, "_metrics"):
            self._metrics[mode]["ebft/alignment"].append(_align)
            self._metrics[mode]["ebft/diversity"].append(_divers)
            self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
            self._metrics[mode]["ebft/reward"].append(_reward)

        return rewards.cpu().tolist()

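The final reward combination reduces to a two-term formula. A plain-Python sketch (the default coefficient values here are illustrative, not the `ebft_alignment_coef`/`ebft_diversity_coef` defaults):

```python
def ebft_reward(align, divers, coef_a=1.0, coef_d=1.0):
    # Both terms are scaled by 2 per paper eq. (7), then combined as
    # coef_a * alignment - coef_d * diversity penalty.
    return coef_a * (2 * align) - coef_d * (2 * divers)
```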
    @torch.no_grad()
    def _sequential_rollout(
        self,
        prompts: list,
        first_completions: list,
        remaining_turns: list,
        num_gens: int,
    ) -> list:
        """
        Extend single-turn completions into multi-turn conversations.

        For each prompt group, takes the first generated assistant turn and
        sequentially generates subsequent assistant turns by calling vLLM,
        building up a full multi-turn conversation.

        Args:
            prompts: List of prompt message lists (repeated num_gens times)
            first_completions: List of generated first-turn completions
            remaining_turns: List of remaining turn pairs after first assistant turn.
                Each element is a list of dicts: [{"role": "user", "content": "..."},
                {"role": "assistant", "content": "...GT..."}]
            num_gens: Number of generations per prompt

        Returns:
            Extended completions incorporating all generated turns
        """
        vllm_client = self.vllm_generation.vllm_client
        max_tokens = getattr(self.args, "max_completion_length", 256)
        temperature = getattr(self.args, "temperature", 0.7)
        gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}

        extended_completions = []

        for idx in range(len(prompts)):
            prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
            first_comp = first_completions[idx]

            # Extract first completion text
            if isinstance(first_comp, list):
                first_text = first_comp[0].get("content", "") if first_comp else ""
            else:
                first_text = first_comp

            # Get remaining turns for this prompt (same for all num_gens copies)
            prompt_idx = idx // num_gens
            turns = (
                remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
            )

            if not turns:
                extended_completions.append(first_text)
                continue

            # Build conversation with generated first turn
            conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]

            # Generate subsequent turns
            for turn in turns:
                if turn["role"] == "user":
                    conv.append(turn)
                elif turn["role"] == "assistant":
                    try:
                        result = vllm_client.chat(
                            messages=[conv],
                            n=1,
                            max_tokens=max_tokens,
                            temperature=temperature,
                            generation_kwargs=gen_kwargs,
                        )
                        gen_ids = result.get("completion_ids", [[]])[0]
                        gen_text = self.processing_class.decode(
                            gen_ids, skip_special_tokens=True
                        )
                    except Exception as e:
                        LOG.warning(f"Multi-turn rollout generation failed: {e}")
                        gen_text = ""

                    conv.append({"role": "assistant", "content": gen_text})

            # Render full conversation through chat template, then extract
            # everything after the original prompt as the "completion" text.
            # This preserves role markers and formatting between turns.
            full_rendered = self.processing_class.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=False
            )
            prompt_rendered = self.processing_class.apply_chat_template(
                prompt_msgs, tokenize=False, add_generation_prompt=True
            )
            completion_text = full_rendered[len(prompt_rendered) :]
            extended_completions.append(completion_text)

        return extended_completions


class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
|
||||
"""EBFT trainer using synchronous GRPO (standard vLLM generation)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
|
||||
"""EBFT trainer using async GRPO (prefetches next batch during training)."""
|
||||
|
||||
pass
|
||||
@@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer):
"""

def __init__(self, *args, **kwargs):
# When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
# The communicator is not needed because weight sync happens via filesystem + HTTP,
# and it fails when vLLM and a trainer rank share the same CUDA device.
# Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
# merged weight sync. NCCL is only needed for the standard update_named_param
# path which broadcasts tensors through the communicator.
training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
if training_args is not None and getattr(
training_args, "vllm_lora_sync", False
):
_skip_nccl = False
if training_args is not None:
if getattr(training_args, "vllm_lora_sync", False):
_skip_nccl = True  # LoRA sync uses filesystem + HTTP
elif getattr(training_args, "async_prefetch", False):
# Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
# init_communicator allocates device tensors on rank 0 only, which
# causes DDP to see different param counts across ranks.
# The communicator is initialized lazily on first weight sync instead.
_skip_nccl = True
if _skip_nccl:
from trl.generation.vllm_generation import VLLMGeneration

_orig_init_vllm = VLLMGeneration._init_vllm
@@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer):

VLLMGeneration._init_vllm = _init_vllm_no_communicator

super().__init__(*args, **kwargs)
try:
super().__init__(*args, **kwargs)
finally:
# Restore original _init_vllm so other trainers aren't affected
if _skip_nccl:
VLLMGeneration._init_vllm = _orig_init_vllm  # type: ignore[possibly-undefined]

# FP8 models: zero out the pad token embedding so that padding
# positions have zero hidden states throughout the network.
@@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._executor = None

def _submit_generation(self):
"""Submit the next background generation job."""
"""Submit the next background generation job.

With multi-process (DDP/FSDP), only rank 0 generates to avoid
cross-rank NCCL collectives from background threads. Non-rank-0
processes enqueue a sentinel ``None`` that is replaced by a
broadcast in ``_prepare_inputs_legacy_async``.
"""
rank0_only = self.accelerator.num_processes > 1
if rank0_only and not self.accelerator.is_main_process:
# Non-rank-0: nothing to generate; enqueue a resolved None future
f: concurrent.futures.Future = concurrent.futures.Future()
f.set_result(None)
self._async_queue.put(f)
return
batch = next(self._prompt_iter)
future = self._executor.submit(self._generate_only, batch)
future = self._executor.submit(self._generate_only, batch, rank0_only)
self._async_queue.put(future)

# ------------------------------------------------------------------
# Broadcast rollout (legacy async, multi-process)
# ------------------------------------------------------------------

def _broadcast_rollout(self, rollout: dict | None) -> dict:
"""Broadcast a rank0-only rollout dict to all ranks (main thread).

Rank 0 has the full rollout dict from ``_generate_only``; other ranks
have ``None``. After broadcast, tensors are moved to each rank's
local device.
"""
import torch.distributed as dist

obj_list = [rollout if self.accelerator.is_main_process else None]
dist.broadcast_object_list(obj_list, src=0)
rollout = obj_list[0]
assert rollout is not None, "broadcast_object_list failed to deliver rollout"

# Move tensors to local device (broadcast deserializes to CPU)
device = self.accelerator.device
for key, val in rollout.items():
if isinstance(val, torch.Tensor) and val.device != device:
rollout[key] = val.to(device)

return rollout

# ------------------------------------------------------------------
# Weight sync
# ------------------------------------------------------------------
@@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer):
for Float8), and also safe for concurrent use since it never modifies base
weights in-place.
"""
model = self.vllm_generation.model
accelerator = self.vllm_generation.accelerator
vllm_client = self.vllm_generation.vllm_client
fix_name = self.vllm_generation._fix_param_name_to_vllm

if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
return

# In multi-GPU async mode, we skip NCCL communicator init to avoid
# DDP param count mismatch and NCCL device conflicts. Weight sync
# uses the HTTP-only fallback in batch_update_named_params instead.

model = self.vllm_generation.model
vllm_client = self.vllm_generation.vllm_client
fix_name = self.vllm_generation._fix_param_name_to_vllm

# Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
lora_info = {}
for mod_name, module in model.base_model.model.named_modules():
@@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
weight_name = pname.replace(".weight_scale_inv", ".weight")
scale_inv_lookup[weight_name] = pparam.data

# Iterate all parameters, computing merged weights for LoRA layers.
# Skip LoRA-specific params and FP8 scale params (scales will be
# recomputed by vLLM when it receives the merged bf16 weight).
# Only sync parameters that have LoRA modifications — skip unchanged
# base weights to avoid OOM on the vLLM GPU from allocating the entire
# model's worth of NCCL receive buffers.
params_to_sync = []
compute_dtype = torch.bfloat16
for name, param in model.named_parameters():
vllm_name = name.removeprefix("base_model.model.").replace(
".base_layer", ""
@@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer):
continue
if "original_module" in vllm_name:
continue
# Skip FP8 quantization scale parameters - they are recomputed
# on the vLLM side when we update the weight itself
if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
continue
if not vllm_name.endswith(".weight"):
continue
# fix_name strips modules_to_save.default. prefix
raw_mod_path = vllm_name[: -len(".weight")]
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
mod_path = vllm_name[: -len(".weight")]

# Sync weights that have LoRA adapters OR are modules_to_save
is_lora = mod_path in lora_info
is_modules_to_save = raw_mod_path != mod_path  # fix_name stripped a prefix
if not is_lora and not is_modules_to_save:
continue

data = param.data
compute_dtype = torch.bfloat16

if vllm_name.endswith(".weight"):
# Dequantize FP8 weights before merging
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
scale_inv = scale_inv_lookup[name]
# Block dequantization: weight * scale_inv (with broadcasting)
fp8_bf16 = data.to(compute_dtype)
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
# Block-quantized: scale_inv shape (rows/block, cols/block)
sr, sc = scale_inv.shape
br = fp8_bf16.shape[0] // sr  # block height
bc = fp8_bf16.shape[1] // sc  # block width
# Reshape → multiply by block scale → reshape back
data = (
fp8_bf16.reshape(sr, br, sc, bc)
* scale_inv[:, None, :, None].to(compute_dtype)
).reshape(fp8_bf16.shape)
elif scale_inv.dim() <= 1:
# Per-tensor or per-channel scale
data = fp8_bf16 * scale_inv.to(compute_dtype)
else:
data = fp8_bf16
elif data.dtype == torch.float8_e4m3fn:
# FP8 but no scale found - just cast (lossy)
data = data.to(compute_dtype)
# Dequantize FP8 weights before merging
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
scale_inv = scale_inv_lookup[name]
fp8_bf16 = data.to(compute_dtype)
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
sr, sc = scale_inv.shape
br = fp8_bf16.shape[0] // sr
bc = fp8_bf16.shape[1] // sc
data = (
fp8_bf16.reshape(sr, br, sc, bc)
* scale_inv[:, None, :, None].to(compute_dtype)
).reshape(fp8_bf16.shape)
elif scale_inv.dim() <= 1:
data = fp8_bf16 * scale_inv.to(compute_dtype)
else:
data = fp8_bf16
elif data.dtype == torch.float8_e4m3fn:
data = data.to(compute_dtype)

mod_path = vllm_name[: -len(".weight")]
if mod_path in lora_info:
A, B, s = lora_info[mod_path]
merged = data.to(compute_dtype) + s * (
B.to(compute_dtype) @ A.to(compute_dtype)
)
data = merged
if is_lora:
A, B, s = lora_info[mod_path]
merged = data.to(compute_dtype) + s * (
B.to(compute_dtype) @ A.to(compute_dtype)
)
params_to_sync.append((vllm_name, merged))
else:
# modules_to_save: send raw weight (no LoRA merge needed)
params_to_sync.append((vllm_name, data.to(compute_dtype)))

params_to_sync.append((vllm_name, data))

# Batch sync all params in one HTTP+NCCL call (vs individual calls)
# Batch sync only LoRA-modified params via HTTP+NCCL
if params_to_sync:
sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
logger.info(
f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
)
vllm_client.batch_update_named_params(params_to_sync)

# Reset prefix cache after weight update
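The block-dequantization step in the hunk above (reshape, multiply by a per-block inverse scale, reshape back) can be sketched in isolation. A minimal example with a hypothetical helper `block_dequantize`; the weight here stands in for an FP8 tensor already cast to a compute dtype, and all values are invented for illustration:

```python
import torch


def block_dequantize(w: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Multiply each (br x bc) block of `w` by its scalar from `scale_inv`.

    `w` is the quantized weight already cast to the compute dtype; `scale_inv`
    has shape (rows // br, cols // bc), one scale per block.
    """
    sr, sc = scale_inv.shape
    br = w.shape[0] // sr  # block height
    bc = w.shape[1] // sc  # block width
    # Reshape to (sr, br, sc, bc) so each block scale broadcasts over its
    # br x bc block, then collapse back to the original 2D shape.
    return (w.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None]).reshape(w.shape)


# 4x4 weight with 2x2 blocks: top-left block scaled by 2, bottom-right by 3
w = torch.ones(4, 4)
scale_inv = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
out = block_dequantize(w, scale_inv)
```

The reshape works because row index `i` factors as `(i // br, i % br)` and column index `j` as `(j // bc, j % bc)`, so the leading and third axes index blocks while the others index positions within a block.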
@@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer):

vllm_client = self.vllm_generation.vllm_client
url = f"{vllm_client.base_url}/set_lora_adapter/"
sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
response = requests.post(
url,
json={
@@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
"lora_int_id": self._lora_sync_version,
"lora_path": adapter_path,
},
timeout=30,
timeout=sync_timeout,
)
if response.status_code != 200:
logger.warning(
@@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
step = self.state.global_step
interval = self.args.vllm_sync_interval
if step != self._last_synced_step and step % interval == 0:
if step == 0:
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
self._last_synced_step = step
return
if getattr(self.args, "vllm_lora_sync", False):
if step == 0:
logger.info("Skipping LoRA sync at step 0 (no training yet)")
self._last_synced_step = step
return
# Native LoRA sync: save adapter to filesystem, vLLM loads it directly
self._sync_lora_adapter()
else:
@@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Background-thread generation (no scoring)
# ------------------------------------------------------------------

def _generate_single_turn(self, prompts, **kwargs):
def _generate_single_turn(self, prompts, *args, **kwargs):
"""Override to prevent weight sync from background thread and to use
no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
is_bg = threading.current_thread() is not threading.main_thread()
@@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._patched_sync_weights = True

try:
return super()._generate_single_turn(prompts, **kwargs)
return super()._generate_single_turn(prompts, *args, **kwargs)
finally:
if saved_step is not None:
self._last_loaded_step = saved_step
@@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
output = vg.vllm_client.chat(
messages=unique_prompts,
**sampling_params,
chat_template_kwargs=vg.chat_template_kwargs,
tools=vg.tools,
chat_template=vg.chat_template,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=getattr(self, "chat_template", None),
)
else:
output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
@@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
logps_diff = per_token_logps_diff

is_ratio = torch.exp(logps_diff)
is_floor = 1.0 / is_cap  # symmetric floor (e.g., cap=3.0 -> floor=0.333)
if is_mode in ("sequence_truncate", "token_truncate"):
is_ratio = torch.clamp(is_ratio, max=is_cap)
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
is_ratio = is_ratio.clamp(min=is_floor)
data["importance_sampling_ratio"] = is_ratio

# --- Collect rewards (launched before logprobs, should be done) ---
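The symmetric clamp added in the hunk above can be illustrated standalone. A minimal sketch, assuming a cap of 3.0 (so the floor is 1/3); the helper name is illustrative, and for the mask modes the mask-then-floor order follows the lines as rendered in the diff:

```python
import torch


def clamp_is_ratio(is_ratio: torch.Tensor, is_cap: float, mode: str) -> torch.Tensor:
    """Apply a symmetric floor/cap to importance-sampling ratios.

    Truncate modes clamp into [1/cap, cap]; mask modes zero out ratios
    above the cap, then floor everything at 1/cap (order as in the diff).
    """
    is_floor = 1.0 / is_cap  # symmetric floor, e.g. cap=3.0 -> floor=1/3
    if mode in ("sequence_truncate", "token_truncate"):
        return torch.clamp(is_ratio, min=is_floor, max=is_cap)
    if mode in ("sequence_mask", "token_mask"):
        return is_ratio.masked_fill(is_ratio > is_cap, 0.0).clamp(min=is_floor)
    return is_ratio


r = torch.tensor([0.1, 1.0, 5.0])
truncated = clamp_is_ratio(r, 3.0, "token_truncate")
masked = clamp_is_ratio(r, 3.0, "token_mask")
```

Note that with this order of operations, masked-out entries are raised back to the floor by the final clamp rather than staying at zero.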
@@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
seq_is = is_mode in ("sequence_mask", "sequence_truncate")
logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
is_ratio = torch.exp(logps_diff)
# Symmetric floor clamp (matches non-streaming path at line ~1651)
is_floor = 1.0 / is_cap
if is_mode in ("sequence_truncate", "token_truncate"):
is_ratio = torch.clamp(is_ratio, max=is_cap)
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
elif is_mode in ("sequence_mask", "token_mask"):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
is_ratio = is_ratio.clamp(min=is_floor)
if "importance_sampling_ratio" not in data:
total = len(data["prompt_ids"])
shape = (total, 1) if seq_is else (total, is_ratio.size(1))
@@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
rollout = future.result()
self._submit_generation()

# With multi-process, only rank 0 generated. Broadcast to all ranks.
if self.accelerator.num_processes > 1:
rollout = self._broadcast_rollout(rollout)

if self.args.streaming_partial_batch:
micro_batches = self._score_streaming(rollout)
else:
@@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback):
logger.info("=" * 60)

if self.trainer.axolotl_cfg.use_wandb:
if wandb.run is not None:
wandb.log(
if wandb.run is not None:  # type: ignore[attr-defined]
wandb.log(  # type: ignore[attr-defined]
{
"generated_samples": wandb.Table(
"generated_samples": wandb.Table(  # type: ignore[attr-defined]
columns=[
"step",
"original",
@@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__)
def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
from transformers import is_torch_xpu_available
"""Batched weight sync — uses NCCL if communicator available, HTTP otherwise."""
has_communicator = getattr(self, "communicator", None) is not None

if chunk_size is None:
chunks = [params]
else:
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
if has_communicator:
# Fast path: metadata via HTTP, tensors via NCCL
from transformers import is_torch_xpu_available

for chunk in chunks:
param_metadata = [
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)

if is_torch_xpu_available():
self.communicator.barrier()
if chunk_size is None:
chunks = [params]
else:
self.communicator.group.barrier()
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)

for chunk in chunks:
param_metadata = [
{
"name": name,
"dtype": str(weights.dtype),
"shape": list(weights.shape),
}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(
url, json={"params": param_metadata}, timeout=120
)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)

for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)

if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
else:
# HTTP-only path: encode tensor data in request body (no NCCL needed).
# Batch by byte size to avoid huge HTTP payloads.
MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024  # 10 MB
HTTP_TIMEOUT = 120  # seconds per request

payload: list[dict] = []
payload_bytes = 0
url = f"{self.base_url}/http_update_weights/"

def _flush(p: list[dict]) -> None:
if not p:
return
response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT)
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)

from axolotl.utils.weight_serde import encode_for_http

for name, weights in params:
entry = encode_for_http(name, weights)
entry_bytes = weights.nelement() * weights.element_size()

# Flush current batch if adding this entry would exceed limit
if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST:
_flush(payload)
payload = []
payload_bytes = 0

payload.append(entry)
payload_bytes += entry_bytes

_flush(payload)  # send remaining


def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
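The greedy size-based chunking used by `_batch_update_named_params` above can be sketched on its own. A minimal version, assuming params are (name, tensor) pairs and the threshold counts tensor elements as in the diff (the helper name is illustrative):

```python
import torch


def chunk_params(params, chunk_size):
    """Greedily split (name, tensor) pairs into chunks of at most
    ~chunk_size total elements; a single oversized tensor still gets
    its own chunk, since a chunk is only flushed when non-empty."""
    if chunk_size is None:
        return [params]
    chunks, current, current_elements = [], [], 0
    for name, weights in params:
        n_elem = weights.numel()
        # Flush the current chunk if this tensor would push it past the limit
        if current and current_elements + n_elem > chunk_size:
            chunks.append(current)
            current, current_elements = [], 0
        current.append((name, weights))
        current_elements += n_elem
    if current:
        chunks.append(current)
    return chunks


params = [("a", torch.zeros(6)), ("b", torch.zeros(5)), ("c", torch.zeros(5))]
chunks = chunk_params(params, chunk_size=10)
```

With a 10-element budget, "a" (6) would overflow once "b" (5) is added, so "a" is flushed alone and "b" and "c" (exactly 10 together) share the second chunk.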
9
src/axolotl/prompt_strategies/ebft/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""
Module for EBFT-style dataset transform strategies.
"""

from functools import partial

from ..base import load as load_base

load = partial(load_base, module_base="axolotl.prompt_strategies.ebft")
129
src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
Normal file
@@ -0,0 +1,129 @@
"""
Dataset transform for multi-turn chat data with structured EBFT (vLLM mode).

Three variants:

1. `transform` — Uses the FIRST assistant turn as the generation target.
   Passes remaining turns as `remaining_turns` for sequential rollout.
   The trainer generates turn 1 via GRPO/vLLM, then sequentially generates
   subsequent assistant turns, comparing the full conversation to GT.

2. `transform_last_turn` — Uses the LAST assistant turn as the target.
   Simplest approach: the full conversation history is the prompt.

3. `transform_all_turns` — Explodes each conversation into N examples
   (one per assistant turn). Each turn is an independent training example.
   Use with batched=True.

Supports OpenAI chat format:
    {"messages": [{"role": ..., "content": ...}, ...]}
"""


def transform(cfg, **kwargs):
    """Multi-turn with sequential rollout.

    Returns the first assistant turn as ground_truth, plus remaining_turns
    for the trainer to do sequential rollout generation.
    """

    def transform_fn(example, tokenizer=None):
        messages = example.get("messages", example.get("conversations", []))

        if not messages:
            return {"prompt": [], "ground_truth": ""}

        # Split at first assistant turn
        prompt_msgs = []
        first_gt = None
        remaining = []

        found_first = False
        for msg in messages:
            if msg["role"] == "assistant" and not found_first:
                first_gt = msg["content"]
                found_first = True
            elif found_first:
                remaining.append(msg)
            else:
                prompt_msgs.append(msg)

        if first_gt is None:
            return {"prompt": prompt_msgs, "ground_truth": ""}

        # Store only the first assistant turn as ground_truth. The full multi-turn
        # GT is reconstructed in the reward function via chat template rendering
        # (using remaining_turns), which preserves role markers between turns.
        return {
            "prompt": prompt_msgs,
            "ground_truth": first_gt,
            "remaining_turns": remaining,
        }

    return transform_fn, {
        "remove_columns": "__all__",
    }


def transform_last_turn(cfg, **kwargs):
    """Single-turn: use the last assistant turn as the generation target."""

    def transform_fn(example, tokenizer=None):
        messages = example.get("messages", example.get("conversations", []))

        if not messages:
            return {"prompt": [], "ground_truth": ""}

        # Find all assistant turns
        history = []
        last_prompt = []
        last_gt = ""
        for msg in messages:
            if msg["role"] == "assistant":
                last_prompt = list(history)
                last_gt = msg["content"]
            history.append(msg)

        return {
            "prompt": last_prompt,
            "ground_truth": last_gt,
        }

    return transform_fn, {
        "remove_columns": "__all__",
    }


def transform_all_turns(cfg, **kwargs):
    """Explode: one example per assistant turn.

    Use with datasets.map(batched=True) to produce N examples from
    each N-turn conversation.

    Usage in YAML:
        type: ebft_chat_multiturn.transform_all_turns
    """

    def transform_fn(examples, tokenizer=None):
        all_prompts = []
        all_ground_truths = []

        messages_list = examples.get("messages", examples.get("conversations", []))

        for messages in messages_list:
            history = []
            for msg in messages:
                if msg["role"] == "assistant":
                    all_prompts.append(list(history))
                    all_ground_truths.append(msg["content"])
                history.append(msg)

        return {
            "prompt": all_prompts,
            "ground_truth": all_ground_truths,
        }

    return transform_fn, {
        "remove_columns": "__all__",
        "batched": True,
    }
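To make the first-turn split in `transform` concrete, here is its splitting loop run on a made-up two-turn conversation (the helper name and messages are invented for illustration):

```python
def split_first_assistant(messages):
    """Mirror of the splitting loop in `transform`: prompt is everything
    before the first assistant turn, ground_truth is that turn's content,
    and remaining holds every later message."""
    prompt_msgs, remaining = [], []
    first_gt, found_first = None, False
    for msg in messages:
        if msg["role"] == "assistant" and not found_first:
            first_gt = msg["content"]
            found_first = True
        elif found_first:
            remaining.append(msg)
        else:
            prompt_msgs.append(msg)
    return prompt_msgs, first_gt, remaining


messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "how are you?"},
    {"role": "assistant", "content": "fine"},
]
prompt, gt, remaining = split_first_assistant(messages)
```

Here the prompt is just the first user message, "hello" becomes the ground truth for turn 1, and the trailing user/assistant pair goes into `remaining` for the sequential rollout.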
20
src/axolotl/prompt_strategies/ebft/ebft_opencode.py
Normal file
@@ -0,0 +1,20 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode.

Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer (prompt + ground_truth).
"""


def transform(cfg, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "user", "content": example["input"]},
            ],
            "ground_truth": example["output"],
        }

    return transform_fn, {
        "remove_columns": "__all__",
    }
319
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
Normal file
@@ -0,0 +1,319 @@
"""
Dataset transform for reasoning/thinking datasets with EBFT.

Handles datasets where assistant responses contain <think>...</think> reasoning
traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs).

Three variants:

1. `transform` — For structured EBFT (vLLM mode):
   Returns prompt + ground_truth with thinking tags preserved.
   Feature matching compares full responses (thinking + answer).

2. `transform_answer_only` — For structured EBFT (vLLM mode):
   Strips <think>...</think> from ground_truth, so feature matching
   only scores the final answer portion. Use when reasoning chains
   can vary but the answer should match.

3. `transform_strided` — For strided EBFT:
   Tokenizes the full conversation with thinking traces.
   Optionally masks thinking tokens from CE loss (labels=-100 for think spans)
   while still placing anchors in thinking regions for feature matching.

All variants work with OpenAI chat format:
    {"messages": [{"role": "...", "content": "<think>...</think>Answer"}]}
"""

import re


def _strip_thinking(text: str) -> str:
    """Remove <think>...</think> blocks from text."""
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()


def _extract_thinking(text: str) -> tuple[str, str]:
    """Split text into (thinking, answer) parts."""
    match = re.search(r"<think>(.*?)</think>\s*(.*)", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return "", text.strip()
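The two regex helpers above can be exercised directly. A quick self-contained check of how a tagged response is split (the helpers are copied from the file; the sample response is made up):

```python
import re


def _strip_thinking(text: str) -> str:
    """Remove <think>...</think> blocks from text (non-greedy, DOTALL)."""
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()


def _extract_thinking(text: str) -> tuple[str, str]:
    """Split text into (thinking, answer) parts; empty thinking if untagged."""
    match = re.search(r"<think>(.*?)</think>\s*(.*)", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return "", text.strip()


resp = "<think>\n2 + 2 is 4\n</think>\nThe answer is 4."
answer_only = _strip_thinking(resp)
thinking, answer = _extract_thinking(resp)
```

`re.DOTALL` lets `.` match the newlines inside the think block, and the non-greedy `.*?` stops at the first closing tag.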
def transform(cfg, **kwargs):
|
||||
"""Full response including thinking traces for feature matching.
|
||||
|
||||
For datasets where assistant content has <think>...</think> tags in the
|
||||
content field. The ground_truth includes the full content (thinking + answer).
|
||||
"""
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
prompt_msgs_snapshot = None
|
||||
ground_truth = ""
|
||||
for msg_idx, msg in enumerate(messages):
|
||||
if msg["role"] == "assistant":
|
||||
prompt_msgs_snapshot = list(messages[:msg_idx])
|
||||
ground_truth = msg["content"]
|
||||
|
||||
return {
|
||||
"prompt": prompt_msgs_snapshot
|
||||
if prompt_msgs_snapshot is not None
|
||||
else messages[:-1],
|
||||
"ground_truth": ground_truth,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
|
||||
|
||||
def transform_split_thinking(cfg, **kwargs):
|
||||
"""Split <think> tags into reasoning_content field for native chat template handling.
|
||||
|
||||
For datasets where thinking is embedded in the content field as <think>...</think>.
|
||||
Splits it into separate reasoning_content and content fields so the model's
|
||||
chat template can format it natively (e.g., Qwen3.5's reasoning_content support).
|
||||
|
||||
The prompt messages are passed through with reasoning_content properly split,
|
||||
so vLLM generation with enable_thinking=true produces comparable outputs.
|
||||
The ground_truth is the full assistant response (thinking + answer) for
|
||||
feature matching.
|
||||
|
||||
Also works for:
|
||||
- <reasoning>...</reasoning> tags
|
||||
- <|begin_of_thought|>...<|end_of_thought|> tags
|
||||
"""
|
||||
_THINKING_PAIRS = [
|
||||
("<think>", "</think>"),
|
||||
("<reasoning>", "</reasoning>"),
|
||||
("<|begin_of_thought|>", "<|end_of_thought|>"),
|
||||
]
|
||||
|
||||
def _split_msg_thinking(msg):
|
||||
"""Split thinking from assistant message content into reasoning_content.
|
||||
|
||||
Always includes reasoning_content key on assistant messages (empty string
|
||||
if no thinking tags found) to ensure consistent HF dataset schema across
|
||||
all examples in a batch.
|
||||
"""
|
||||
if msg["role"] != "assistant":
|
||||
return msg
|
||||
content = msg.get("content", "")
|
||||
# Already has reasoning_content — pass through
|
||||
if "reasoning_content" in msg:
|
||||
return msg
|
||||
for open_tag, close_tag in _THINKING_PAIRS:
|
||||
if open_tag in content and close_tag in content:
|
||||
start = content.find(open_tag)
|
||||
end = content.find(close_tag)
|
||||
thinking = content[start + len(open_tag) : end].strip()
|
||||
answer = content[end + len(close_tag) :].strip()
|
||||
return {
|
||||
**msg,
|
||||
"reasoning_content": thinking,
|
||||
"content": answer,
|
||||
}
|
||||
# No thinking tags — still add reasoning_content for schema consistency
|
||||
return {**msg, "reasoning_content": ""}
|
||||
|
||||
def _normalize_msg(msg):
|
||||
"""Ensure every message has {role, content, reasoning_content} for HF schema consistency."""
|
||||
return {
|
||||
"role": msg.get("role", ""),
|
||||
"content": msg.get("content", ""),
|
||||
"reasoning_content": msg.get("reasoning_content", ""),
|
||||
}
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
# Split thinking in all assistant messages, then normalize schema
|
||||
split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages]
|
||||
|
||||
# Build prompt (all messages except last assistant) and ground_truth
|
||||
prompt_msgs = []
|
||||
prompt_msgs_snapshot = None
|
||||
ground_truth = ""
|
||||
for msg in split_messages:
|
||||
if msg["role"] == "assistant":
|
||||
prompt_msgs_snapshot = list(prompt_msgs)
|
||||
# ground_truth is the FULL content for feature matching
|
||||
thinking = msg.get("reasoning_content", "")
|
||||
answer = msg.get("content", "")
|
||||
if thinking:
|
||||
ground_truth = f"<think>\n{thinking}\n</think>\n\n{answer}"
|
||||
else:
|
||||
ground_truth = answer
|
||||
prompt_msgs.append(msg)
|
||||
|
||||
return {
|
||||
"prompt": prompt_msgs_snapshot
|
||||
if prompt_msgs_snapshot is not None
|
||||
else split_messages[:-1],
|
||||
"ground_truth": ground_truth,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
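The tag-splitting rule used above (take the first open tag and first close tag, strip both sides) can be sketched standalone; the helper name `split_thinking` is ours for illustration, not part of axolotl:

```python
def split_thinking(content, open_tag="<think>", close_tag="</think>"):
    # Mirrors _split_msg_thinking: everything between the first tag pair
    # becomes the reasoning; everything after the close tag is the answer.
    if open_tag in content and close_tag in content:
        start = content.find(open_tag)
        end = content.find(close_tag)
        thinking = content[start + len(open_tag):end].strip()
        answer = content[end + len(close_tag):].strip()
        return thinking, answer
    return "", content.strip()

print(split_thinking("<think>2+2=4</think>The answer is 4."))
```

With no tags present, the reasoning comes back empty and the content passes through, matching the schema-consistency behavior described above.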


def transform_answer_only(cfg, **kwargs):
    """Strip thinking from ground_truth — match features on the answer only."""

    def transform_fn(example, tokenizer=None):
        messages = example.get("messages", example.get("conversations", []))

        prompt_msgs = []
        prompt_msgs_snapshot = None
        ground_truth = ""
        for msg in messages:
            if msg["role"] == "assistant":
                prompt_msgs_snapshot = list(prompt_msgs)
                ground_truth = _strip_thinking(msg["content"])
            prompt_msgs.append(msg)

        return {
            "prompt": prompt_msgs_snapshot
            if prompt_msgs_snapshot is not None
            else messages[:-1],
            "ground_truth": ground_truth,
        }

    return transform_fn, {"remove_columns": "__all__"}


def transform_strided(cfg, **kwargs):
    """For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss.

    Config options (via cfg):
    - ebft.mask_thinking_ce: bool (default False)
      If True, set labels=-100 for tokens inside <think>...</think> blocks.
      Feature matching still uses these positions (anchors are placed everywhere
      in the completion span). Only the CE auxiliary loss is affected.
    """
    seq_len = cfg.sequence_len
    mask_thinking = False
    if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"):
        mask_thinking = cfg.ebft.mask_thinking_ce

    def transform_fn(example, tokenizer=None):
        messages = example.get("messages", example.get("conversations", []))

        if tokenizer is None:
            for m in messages:
                if m.get("role") == "user":
                    return {"prompt": m["content"]}
            return {"prompt": str(messages)}

        pad_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        # Tokenize the full conversation with the chat template
        full_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        full_enc = tokenizer(
            full_text,
            truncation=True,
            max_length=seq_len,
            add_special_tokens=False,
            return_tensors=None,
        )
        input_ids = full_enc["input_ids"]

        # Build labels: -100 for non-assistant tokens
        labels = [-100] * len(input_ids)

        # Find assistant turn boundaries using incremental tokenization.
        # Only the FINAL assistant turn is marked as trainable.
        prefix_messages = []
        final_start = None
        final_end = None
        for msg in messages:
            if msg["role"] == "assistant":
                prefix_text = tokenizer.apply_chat_template(
                    prefix_messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                prefix_ids = tokenizer(
                    prefix_text,
                    truncation=True,
                    max_length=seq_len,
                    add_special_tokens=False,
                    return_tensors=None,
                )["input_ids"]
                start = len(prefix_ids)

                prefix_messages.append(msg)
                with_turn_text = tokenizer.apply_chat_template(
                    prefix_messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                with_turn_ids = tokenizer(
                    with_turn_text,
                    truncation=True,
                    max_length=seq_len,
                    add_special_tokens=False,
                    return_tensors=None,
                )["input_ids"]
                end = len(with_turn_ids)

                # Record this turn's boundaries; only the last one will be used
                final_start = start
                final_end = end
            else:
                prefix_messages.append(msg)

        # Mark only the final assistant turn as trainable
        if final_start is not None and final_end is not None:
            for i in range(final_start, min(final_end, len(labels))):
                labels[i] = input_ids[i]

            # Optionally mask <think>...</think> tokens within this turn.
            # Find think spans by scanning for <think> and </think> token IDs
            # directly in the input_ids (robust to tokenization alignment).
            if mask_thinking:
                think_open_id = tokenizer.convert_tokens_to_ids("<think>")
                think_close_id = tokenizer.convert_tokens_to_ids("</think>")
                if think_open_id != tokenizer.unk_token_id:
                    # Scan from before the assistant turn start to catch
                    # <think> tags that are part of the template prefix
                    scan_start = max(0, final_start - 5)
                    in_think = False
                    for i in range(scan_start, min(final_end, len(labels))):
                        if input_ids[i] == think_open_id:
                            in_think = True
                        if in_think and i >= final_start:
                            labels[i] = -100
                        if input_ids[i] == think_close_id:
                            in_think = False
                            if i >= final_start:
                                labels[i] = -100

        # Derive prompt_length
        prompt_length = len(input_ids)
        for i, lbl in enumerate(labels):
            if lbl != -100:
                prompt_length = i
                break

        # Pad
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    return transform_fn, {"remove_columns": "__all__"}
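The `mask_thinking_ce` scan above can be illustrated on toy token ids (the ids, span positions, and helper name here are made up for illustration):

```python
def mask_think_tokens(input_ids, labels, final_start, final_end,
                      think_open_id, think_close_id, scan_back=5):
    """Set labels to -100 between the think-open and think-close token ids,
    mirroring the in_think scan in transform_strided (toy re-implementation)."""
    in_think = False
    scan_start = max(0, final_start - scan_back)
    for i in range(scan_start, min(final_end, len(labels))):
        if input_ids[i] == think_open_id:
            in_think = True
        if in_think and i >= final_start:
            labels[i] = -100
        if input_ids[i] == think_close_id:
            in_think = False
            if i >= final_start:
                labels[i] = -100
    return labels

# ids 100/101 stand in for <think>/</think>; the assistant turn starts at index 2
ids = [7, 8, 100, 1, 2, 101, 3, 4]
labels = [-100, -100, 100, 1, 2, 101, 3, 4]
print(mask_think_tokens(ids, labels, final_start=2, final_end=8,
                        think_open_id=100, think_close_id=101))
# [-100, -100, -100, -100, -100, -100, 3, 4]
```

Only the CE labels are zeroed out; the answer tokens after `</think>` (ids 3 and 4) stay trainable.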

src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""
Dataset transform for multi-turn chat data with strided EBFT.

Tokenizes conversations using the model's chat template, producing input_ids
with labels=-100 for system/user turns and real labels for assistant turns.
The strided trainer places anchors only within assistant completion spans.

Works with datasets in OpenAI chat format:
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
"""


def transform(cfg, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        messages = example.get("messages", example.get("conversations", []))

        if tokenizer is None:
            # For preview: just return the first user message
            for m in messages:
                if m.get("role") == "user":
                    return {"prompt": m["content"]}
            return {"prompt": str(messages)}

        pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id

        # Tokenize the full conversation with the chat template
        full_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        full_enc = tokenizer(
            full_text,
            truncation=True,
            max_length=seq_len,
            add_special_tokens=False,
            return_tensors=None,
        )
        input_ids = full_enc["input_ids"]

        # Build labels: -100 for everything except assistant turns.
        # Strategy: tokenize incrementally to find assistant turn boundaries.
        labels = [-100] * len(input_ids)

        # Tokenize prefix up to each assistant turn to find boundaries
        prefix_messages = []
        for msg in messages:
            if msg["role"] == "assistant":
                # Tokenize prefix (everything before this assistant turn + generation prompt)
                prefix_text = tokenizer.apply_chat_template(
                    prefix_messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                prefix_ids = tokenizer(
                    prefix_text,
                    truncation=True,
                    max_length=seq_len,
                    add_special_tokens=False,
                    return_tensors=None,
                )["input_ids"]
                start = len(prefix_ids)

                # Tokenize prefix + this assistant turn
                prefix_messages.append(msg)
                with_turn_text = tokenizer.apply_chat_template(
                    prefix_messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                with_turn_ids = tokenizer(
                    with_turn_text,
                    truncation=True,
                    max_length=seq_len,
                    add_special_tokens=False,
                    return_tensors=None,
                )["input_ids"]
                end = len(with_turn_ids)

                # Mark assistant tokens as trainable
                for i in range(start, min(end, len(labels))):
                    labels[i] = input_ids[i]
            else:
                prefix_messages.append(msg)

        # Derive prompt_length as the position of the first non-masked label
        prompt_length = len(input_ids)  # default: all masked
        for i, lbl in enumerate(labels):
            if lbl != -100:
                prompt_length = i
                break

        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    return transform_fn, {
        "remove_columns": "__all__",
    }
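The incremental-tokenization trick above relies on the chat template being prefix-stable: the token count of the rendered prefix (with generation prompt) gives the assistant span's start, and the count with the turn appended gives its end. A toy sketch with a whitespace "tokenizer" standing in for the real chat-template one:

```python
def tokenize(text):
    # Whitespace stand-in for a real tokenizer; the boundary logic is the same.
    return text.split()

prefix = "system: be brief user: 2+2?"
turn = prefix + " assistant: 4"
start = len(tokenize(prefix))  # index of the first assistant token
end = len(tokenize(turn))      # one past the last assistant token
full_ids = tokenize(turn)
labels = [-100] * len(full_ids)
for i in range(start, end):
    labels[i] = i  # stand-in for input_ids[i]
print(start, end, labels)
# 5 7 [-100, -100, -100, -100, -100, 5, 6]
```

Only the assistant span carries real labels; everything before it stays masked at -100.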

@@ -0,0 +1,80 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.

Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).

Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""


def transform(cfg, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        # Extract prompt and completion from the example
        prompt_text = example.get(
            "input", example.get("prompt", example.get("question", ""))
        )
        completion_text = example.get(
            "output", example.get("completion", example.get("answer", ""))
        )

        if tokenizer is None:
            return {"prompt": prompt_text}

        pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id

        # Tokenize prompt and completion separately
        prompt_enc = tokenizer(
            prompt_text,
            truncation=False,
            add_special_tokens=True,
            return_tensors=None,
        )
        completion_enc = tokenizer(
            completion_text,
            truncation=False,
            add_special_tokens=False,
            return_tensors=None,
        )

        prompt_ids = prompt_enc["input_ids"]
        completion_ids = completion_enc["input_ids"]

        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
        total_len = len(prompt_ids) + len(completion_ids)
        if total_len > seq_len:
            # Truncate completion first, then prompt if needed
            max_completion = seq_len - len(prompt_ids)
            if max_completion < 1:
                # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
                prompt_ids = prompt_ids[: seq_len - 1]
                completion_ids = completion_ids[:1]
            else:
                completion_ids = completion_ids[:max_completion]

        input_ids = prompt_ids + completion_ids
        prompt_length = len(prompt_ids)

        # Labels: -100 for prompt tokens, input_ids for completion tokens
        labels = [-100] * prompt_length + completion_ids

        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    # Signal to remove all original columns (filtered to existing ones at map time)
    return transform_fn, {
        "remove_columns": "__all__",
    }
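The truncation policy above (trim the completion first; if the prompt alone overflows, keep seq_len - 1 prompt tokens plus one completion token) can be sketched standalone; `fit_to_seq_len` is our name for it:

```python
def fit_to_seq_len(prompt_ids, completion_ids, seq_len):
    # Mirrors the truncation branch in transform_fn above.
    if len(prompt_ids) + len(completion_ids) > seq_len:
        max_completion = seq_len - len(prompt_ids)
        if max_completion < 1:
            # Prompt alone exceeds seq_len: keep at least 1 completion token
            prompt_ids = prompt_ids[: seq_len - 1]
            completion_ids = completion_ids[:1]
        else:
            completion_ids = completion_ids[:max_completion]
    return prompt_ids, completion_ids

print(fit_to_seq_len([1, 2, 3], [4, 5, 6, 7], seq_len=5))     # ([1, 2, 3], [4, 5])
print(fit_to_seq_len([1, 2, 3, 4, 5, 6], [7, 8], seq_len=4))  # ([1, 2, 3], [7])
```

Either way the result totals at most seq_len tokens and always retains at least one completion token to anchor on.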

@@ -241,6 +241,23 @@ def main(script_args: ScriptArguments):

    app = FastAPI(lifespan=lifespan)

+    # --- Access logging middleware ---
+    import time as _time
+
+    @app.middleware("http")
+    async def access_log_middleware(request, call_next):
+        t0 = _time.monotonic()
+        response = await call_next(request)
+        elapsed = _time.monotonic() - t0
+        logger.info(
+            "%s %s %d %.3fs",
+            request.method,
+            request.url.path,
+            response.status_code,
+            elapsed,
+        )
+        return response
+
    # --- Active LoRA state (shared across endpoints via closure) ---
    active_lora: dict = {"request": None}
@@ -300,7 +317,11 @@ def main(script_args: ScriptArguments):

        import vllm
        from packaging.version import Version
-        from vllm.sampling_params import GuidedDecodingParams
+
+        try:
+            from vllm.sampling_params import GuidedDecodingParams
+        except ImportError:
+            GuidedDecodingParams = None  # not available in vLLM 0.17+

        images: list[str | None] = request.images or [None] * len(request.prompts)  # type: ignore[assignment,list-item]
        prompts: list[dict[str, Any]] = []
@@ -362,7 +383,12 @@ def main(script_args: ScriptArguments):
            }
            conn.send({"type": "call", "method": "generate", "kwargs": kwargs})

-        all_outputs = [conn.recv() for conn in connections]
+        # Use run_in_executor so blocking recv() doesn't freeze the event loop
+        # (allows /set_lora_adapter/ and other endpoints to be served concurrently)
+        loop = asyncio.get_running_loop()
+        all_outputs = await asyncio.gather(
+            *(loop.run_in_executor(None, conn.recv) for conn in connections)
+        )
        all_outputs = [
            o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
        ]
@@ -404,7 +430,10 @@ def main(script_args: ScriptArguments):
            }
            conn.send({"type": "call", "method": "chat", "kwargs": kwargs})

-        all_outputs = [conn.recv() for conn in connections]
+        loop = asyncio.get_running_loop()
+        all_outputs = await asyncio.gather(
+            *(loop.run_in_executor(None, conn.recv) for conn in connections)
+        )
        all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
        all_outputs = list(chain.from_iterable(all_outputs))
@@ -474,11 +503,51 @@ def main(script_args: ScriptArguments):
        )
        return {"message": f"Batch update for {len(params_list)} params"}

+    class HTTPWeightUpdateRequest(BaseModel):
+        """Weight update via HTTP (no NCCL needed)."""
+
+        params: list[
+            dict
+        ]  # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]
+
+    @app.post("/http_update_weights/")
+    async def http_update_weights(request: HTTPWeightUpdateRequest):
+        """Update model weights via HTTP — no NCCL communicator required.
+
+        Tensor data is sent as base64-encoded raw bytes in the request body.
+        Slower than NCCL for large models but works without cross-process setup.
+        """
+        from axolotl.utils.weight_serde import (
+            decode_from_http,
+            encode_for_ipc,
+        )
+
+        weights_to_load = [decode_from_http(p) for p in request.params]
+
+        # Send all weights in a single IPC call. Tensors don't survive
+        # vLLM's multiproc IPC, so serialize as raw bytes + metadata.
+        param_entries = [
+            encode_for_ipc(name, weight) for name, weight in weights_to_load
+        ]
+        kwargs = {
+            "method": "http_load_weights_batch",
+            "kwargs": {"params": param_entries},
+        }
+        msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
+        loop = asyncio.get_running_loop()
+        await asyncio.gather(
+            *(loop.run_in_executor(None, c.send, msg) for c in connections)
+        )
+        return {"message": f"HTTP weight update for {len(weights_to_load)} params"}

    @app.post("/reset_prefix_cache/")
    async def reset_prefix_cache():
        for conn in connections:
            conn.send({"type": "call", "method": "reset_prefix_cache"})
-        results = [conn.recv() for conn in connections]
+        loop = asyncio.get_running_loop()
+        results = await asyncio.gather(
+            *(loop.run_in_executor(None, conn.recv) for conn in connections)
+        )
        return {"message": f"Reset prefix cache: {all(results)}"}

    @app.post("/close_communicator/")
@@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
        model = self.model_runner.model
        params_dict = dict(model.named_parameters())

+        # Handle VLM models where trainer and vLLM use different prefixes.
+        # Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..."
+        # vLLM (Qwen3.5): "language_model.model.layers.X..."
+        if name not in params_dict:
+            # Try common prefix remappings
+            for src_prefix, dst_prefix in [
+                ("model.language_model.layers.", "language_model.model.layers."),
+                ("model.layers.", "language_model.model.layers."),
+            ]:
+                if name.startswith(src_prefix):
+                    name = dst_prefix + name[len(src_prefix) :]
+                    break
+
        # Check if this is a simple direct param (exists as-is)
        if name in params_dict:
            params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
@@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
            return

        # Fallback: try load_weights (may work for non-stacked params)
-        logger.warning("Falling back to load_weights for param: %s", name)
+        # Log the actual param names available for debugging
+        sample_keys = [
+            k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k
+        ][:3]
+        logger.warning(
+            "Falling back to load_weights for param: %s (sample vLLM keys: %s)",
+            name,
+            sample_keys,
+        )
        model.load_weights(weights=[(name, weight)])

    def update_named_param(self, name, dtype, shape):
@@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
        # Load weights using direct set (handles stacked params)
        for name, weight in weights_to_load:
            self._direct_set_weight(name, weight)

+    def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]):
+        """Load weights received via HTTP (no NCCL needed)."""
+        for name, weight in weights:
+            self._direct_set_weight(name, weight.to(self.device))
+
+    def http_load_weight(self, **kwargs):
+        """Load a single weight received via HTTP (no NCCL needed).
+
+        Reconstructs the tensor from raw bytes since tensors don't survive
+        vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights``,
+        which handles TP sharding and stacked-param packing automatically.
+        """
+        from axolotl.utils.weight_serde import decode_from_ipc
+
+        name, weight = decode_from_ipc(kwargs)
+        model = self.model_runner.model
+        model.load_weights(weights=[(name, weight)])
+
+    def http_load_weights_batch(self, params: list[dict]):
+        """Load multiple weights in a single IPC call.
+
+        Uses vLLM's ``load_weights``, which handles TP sharding automatically.
+        """
+        from axolotl.utils.weight_serde import decode_from_ipc
+
+        model = self.model_runner.model
+        weights = [decode_from_ipc(p) for p in params]
+        model.load_weights(weights=weights)
@@ -138,7 +138,11 @@ def setup_reference_model(
        model_ref = None  # explicit setting to None
    else:
        reference_model: bool = True
-        if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
+        trl_cfg = getattr(cfg, "trl", None)
+        if (
+            cfg.rl in {RLType.GRPO, RLType.EBFT}
+            and getattr(trl_cfg, "beta", 0) == 0
+        ):
            reference_model = False
        # load the model again for model_ref/baseline
        model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)

@@ -206,7 +210,7 @@ def execute_training(
                gradient_accumulation_steps=cfg.gradient_accumulation_steps,
                ring_attn_func=cfg.ring_attn_func,
                heads_k_stride=cfg.heads_k_stride,
-                gather_outputs=cfg.rl is RLType.GRPO,
+                gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT},
                device_mesh=trainer.accelerator.torch_device_mesh,
            )
        )
@@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
                    ].append(pred_step_text)
                    row_index += 1
            if logger == "wandb":
-                # type: ignore[attr-defined]
-                wandb.run.log(
+                wandb.run.log(  # type: ignore[attr-defined]
                    {
                        f"{name} - Predictions vs Ground Truth": pd.DataFrame(
                            table_data

@@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                ) as temp_file:
                    copyfile(self.axolotl_config_path, temp_file.name)
-                    artifact = wandb.Artifact(
-                        f"config-{wandb.run.id}", type="axolotl-config"
+                    artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                        f"config-{wandb.run.id}",  # type: ignore[attr-defined]
+                        type="axolotl-config",
                    )
                    artifact.add_file(temp_file.name)
-                    wandb.log_artifact(artifact)
-                    wandb.save(temp_file.name)
+                    wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                    wandb.save(temp_file.name)  # type: ignore[attr-defined]
                    LOG.info(
                        "The Axolotl config has been saved to the WandB run under files."
                    )

@@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                    temp_ct_file.write(str(chat_tpl))
                    temp_ct_file.flush()

-                    artifact = wandb.Artifact(
-                        f"chat-template-{wandb.run.id}", type="jinja-template"
+                    artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                        f"chat-template-{wandb.run.id}",  # type: ignore[attr-defined]
+                        type="jinja-template",
                    )
                    artifact.add_file(temp_ct_file.name)
-                    wandb.log_artifact(artifact)
-                    wandb.save(temp_ct_file.name)
+                    wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                    wandb.save(temp_ct_file.name)  # type: ignore[attr-defined]
                    LOG.info(
                        "The chat_template_jinja has been saved to the WandB run under files."
                    )

@@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                    else:
                        skip_upload = True
                    if not skip_upload:
-                        artifact = wandb.Artifact(
-                            f"deepspeed-config-{wandb.run.id}",
+                        artifact = wandb.Artifact(  # type: ignore[attr-defined]
+                            f"deepspeed-config-{wandb.run.id}",  # type: ignore[attr-defined]
                            type="deepspeed-config",
                        )
                        artifact.add_file(temp_file.name)
-                        wandb.log_artifact(artifact)
-                        wandb.save(temp_file.name)
+                        wandb.log_artifact(artifact)  # type: ignore[attr-defined]
+                        wandb.save(temp_file.name)  # type: ignore[attr-defined]
                        LOG.info(
                            "The DeepSpeed config has been saved to the WandB run under files."
                        )
@@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback):
        if not getattr(cfg, "generate_samples", False):
            return

        dataloader = None
        try:
            if getattr(self.trainer, "eval_dataset", None) is not None:
                dataloader = self.trainer.get_eval_dataloader()
                LOG.info(
                    f"Using eval dataloader for generation at step {state.global_step}"
                )
        except Exception as e:
            LOG.warning(f"Could not get eval dataloader: {e}")
            dataloader = None

        if dataloader is None:
            dataloader = self.trainer.get_train_dataloader()
            LOG.info(
                f"Using train dataloader for generation at step {state.global_step}"
            )

        samples = generate_samples(
            model=self.trainer.model,
            tokenizer=self.trainer.processing_class,
            dataloader=dataloader,
            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
            max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
            temperature=getattr(cfg, "generation_temperature", 0.7),
            top_p=getattr(cfg, "generation_top_p", None),
            top_k=getattr(cfg, "generation_top_k", None),
            do_sample=getattr(cfg, "generation_do_sample", True),
            prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
        )
        self._log_samples(samples, state.global_step)

    def _log_samples(self, samples: list, step: int):
        """Log generated samples to console and W&B."""
@@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback):
        try:
            import wandb

-            if wandb.run is not None:
-                wandb.log(
+            if wandb.run is not None:  # type: ignore[attr-defined]
+                wandb.log(  # type: ignore[attr-defined]
                    {
-                        f"samples/sample_{i + 1}": wandb.Html(
+                        f"samples/sample_{i + 1}": wandb.Html(  # type: ignore[attr-defined]
                            f"<pre>{wandb_text}</pre>"
                        )
                    },
@@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer

from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.lock import FileLockLoader

@@ -173,7 +174,7 @@ def _drop_long_sequences(

        return (len_prompt + len_completion) <= sequence_len

-    if rl in {RLType.GRPO, RLType.GDPO}:
+    if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}:
        return True

    raise ValueError("Unknown RL type")
@@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
            ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
        elif cfg.rl is RLType.KTO:
            ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
+        elif cfg.rl is RLType.EBFT:
+            ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i)
        else:
            ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)

        map_kwargs: dict[str, Any] = {}
        if isinstance(ds_transform_fn, tuple):
            ds_transform_fn, map_kwargs = ds_transform_fn
+        # Handle remove_columns: "__all__" removes all original columns,
+        # or filter a list to only columns that exist in the dataset
+        if "remove_columns" in map_kwargs:
+            ds_columns = (
+                dataset.column_names
+                if isinstance(dataset, Dataset)
+                else dataset[split].column_names
+                if isinstance(dataset, DatasetDict)
+                else []
+            )
+            if map_kwargs["remove_columns"] == "__all__":
+                map_kwargs["remove_columns"] = list(ds_columns)
+            else:
+                map_kwargs["remove_columns"] = [
+                    c for c in map_kwargs["remove_columns"] if c in ds_columns
+                ]
        split_datasets[i] = _map_dataset(
            cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
        )
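The `remove_columns` resolution above can be sketched with plain lists in place of a `datasets.Dataset`; the helper name `resolve_remove_columns` is ours:

```python
def resolve_remove_columns(remove_columns, ds_columns):
    # "__all__" expands to every existing column; an explicit list is
    # filtered down to columns that actually exist in the dataset.
    if remove_columns == "__all__":
        return list(ds_columns)
    return [c for c in remove_columns if c in ds_columns]

cols = ["messages", "source", "id"]
print(resolve_remove_columns("__all__", cols))              # ['messages', 'source', 'id']
print(resolve_remove_columns(["messages", "extra"], cols))  # ['messages']
```

Filtering to existing columns avoids the `datasets.map` error raised when asked to remove a column the dataset does not have.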

@@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig

LOG = get_logger(__name__)


+class EBFTConfig(BaseModel):
+    """Configuration for Energy-Based Fine-Tuning (EBFT)"""
+
+    feature_layers: list[float] = Field(
+        default=[0.25, 0.5, 0.75],
+        json_schema_extra={
+            "description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])"
+        },
+    )
+    embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = (
+        Field(
+            default="last_token",
+            json_schema_extra={
+                "description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"
+            },
+        )
+    )
+    use_whitening: bool = Field(
+        default=False,
+        json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
+    )
+    alignment_coef: float = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Coefficient for alignment reward (cosine similarity with ground truth)"
+        },
+    )
+    diversity_coef: float = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Coefficient for diversity penalty (pairwise similarity between samples)"
+        },
+    )
+    ce_coef: float = Field(
+        default=0.0,
+        json_schema_extra={
+            "description": "Cross-entropy loss coefficient on ground-truth tokens"
+        },
+    )
+    adaptive_max_tokens: bool = Field(
+        default=True,
+        json_schema_extra={
+            "description": "Set per-batch max_tokens based on ground-truth length"
+        },
+    )
+    gt_length_multiplier: float = Field(
+        default=1.5,
+        ge=0.1,
+        json_schema_extra={
+            "description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
+        },
+    )
+
+    # Strided mode fields (for unstructured text)
+    mode: Literal["structured", "strided"] = Field(
+        default="structured",
+        json_schema_extra={
+            "description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
+        },
+    )
+    stride: int = Field(
+        default=8,
+        ge=1,
+        json_schema_extra={"description": "Stride between anchor points (tokens)"},
+    )
+    context_length: int = Field(
+        default=8,
+        ge=1,
+        json_schema_extra={"description": "Context window size per block"},
+    )
+    generate_max_len: int = Field(
+        default=8,
+        ge=1,
+        json_schema_extra={"description": "Tokens to generate per block"},
+    )
+    n_samples_per_prompt: int = Field(
+        default=4,
+        ge=1,
+        json_schema_extra={"description": "Independent rollouts per document"},
+    )
+    temperature: float = Field(
+        default=0.6,
+        ge=0.0,
+        json_schema_extra={
+            "description": "Sampling temperature for strided generation"
+        },
+    )
+    top_p: float = Field(
+        default=1.0,
+        ge=0.0,
+        le=1.0,
+        json_schema_extra={"description": "Top-p nucleus sampling threshold"},
+    )
+    rl_coef: float = Field(
+        default=1.0,
+        json_schema_extra={"description": "RL policy gradient loss coefficient"},
+    )
+    advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
+        default="rloo",
+        json_schema_extra={
+            "description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
+        },
+    )
+    min_completion_prefix: int = Field(
+        default=0,
+        ge=0,
+        json_schema_extra={
+            "description": "Minimum tokens into completion before placing anchors. "
+            "Skips anchors too close to the prompt boundary where features are dominated by prompt context."
+        },
+    )
|
||||
|
||||
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
ModelOutputConfig,
|
||||
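The three `advantage_estimator` options name standard policy-gradient baselines over a group of rollouts. As a minimal sketch of what each typically computes (not the trainer's actual implementation; the function name is illustrative):

```python
def advantages(rewards, estimator="rloo"):
    """Per-sample advantages for a group of rollout rewards."""
    n = len(rewards)
    mean = sum(rewards) / n
    if estimator == "reinforce":
        # Raw rewards, no baseline subtraction.
        return list(rewards)
    if estimator == "rloo":
        # Leave-one-out baseline: mean reward of the *other* samples.
        return [r - (sum(rewards) - r) / (n - 1) for r in rewards]
    if estimator == "group_norm":
        # Normalize by group mean and standard deviation.
        std = (sum((r - mean) ** 2 for r in rewards) / n) ** 0.5
        return [(r - mean) / (std + 1e-8) for r in rewards]
    raise ValueError(estimator)

print(advantages([1.0, 0.0, 0.0, 1.0], "rloo"))
```

With `n_samples_per_prompt: 4` each sample's baseline excludes itself, so a reward above the rest of the group yields a positive advantage even when all rewards are close.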
@@ -131,7 +244,7 @@ class AxolotlInputConfig(
    rl: RLType | None = Field(
        default=None,
        json_schema_extra={
            "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
            "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'"
        },
    )
    trl: TRLConfig | None = Field(
@@ -140,6 +253,12 @@ class AxolotlInputConfig(
    vllm: VllmConfig | None = Field(
        default_factory=lambda: VllmConfig(),
    )
    ebft: EBFTConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
        },
    )
    qat: QATConfig | None = None
    quantization: PTQConfig | None = None
    reward_model: bool | None = Field(
@@ -35,6 +35,7 @@ class RLType(str, Enum):
    ORPO = "orpo"
    KTO = "kto"
    SIMPO = "simpo"
    EBFT = "ebft"


class ChatTemplate(str, Enum):
@@ -1,6 +1,6 @@
"""Pydantic models for TRL trainer configuration"""

from typing import Literal
from typing import Any, Literal

from pydantic import BaseModel, Field

@@ -133,6 +133,20 @@ class TRLConfig(BaseModel):
            "description": "Penalty for tokens that appear in prompt and generated text."
        },
    )
    generation_kwargs: dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Additional generation parameters passed to vLLM SamplingParams. "
            "Useful for stop_token_ids, seed, frequency_penalty, etc."
        },
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Additional kwargs for the chat template. "
            "E.g., {enable_thinking: false} for Qwen3.5 models."
        },
    )
    num_iterations: int | None = Field(
        default=None,
        json_schema_extra={
@@ -1482,6 +1482,124 @@ class DistributedValidationMixin:
        return self


class EBFTValidationMixin:
    """Validation for EBFT (Energy-Based Fine-Tuning) configuration."""

    @model_validator(mode="before")
    @classmethod
    def check_ebft_config_required(cls, data):
        """rl: ebft requires an ebft config section."""
        if data.get("rl") == "ebft" and not data.get("ebft"):
            raise ValueError(
                "`ebft` config section is required when `rl: ebft` is set. "
                "Add an `ebft:` section with at least `mode: structured` or `mode: strided`."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_ebft_torch_compile(cls, data):
        """torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles
        and CheckpointErrors. The flex_attention kernel compiles itself internally —
        whole-model torch.compile is not needed and actively harmful."""
        if (
            data.get("rl") == "ebft"
            and data.get("torch_compile") is True
            and data.get("ebft", {}).get("mode") == "strided"
        ):
            if data.get("gradient_checkpointing"):
                raise ValueError(
                    "EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` "
                    "causes CheckpointError (BlockMask metadata mismatch during recomputation). "
                    "Remove `torch_compile` — the flex_attention kernel compiles itself internally."
                )
            LOG.warning(
                "EBFT strided mode: `torch_compile: true` causes dynamo recompiles from "
                "variable sequence lengths across steps. Consider removing it — "
                "flex_attention compiles itself internally."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_ebft_gradient_checkpointing_reentrant(cls, data):
        """flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
        if (
            data.get("rl") == "ebft"
            and data.get("ebft", {}).get("mode") == "strided"
            and data.get("flex_attention")
            and data.get("gradient_checkpointing")
        ):
            gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
            if not gc_kwargs.get("use_reentrant"):
                LOG.warning(
                    "EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
                    "gradient_checkpointing_kwargs (required for flex_attention compatibility). "
                    "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
                )
                if data.get("gradient_checkpointing_kwargs") is None:
                    data["gradient_checkpointing_kwargs"] = {}
                data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
        return data

    @model_validator(mode="before")
    @classmethod
    def check_ebft_activation_offloading(cls, data):
        """activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
        which conflicts with flex_attention's use_reentrant requirement."""
        if (
            data.get("rl") == "ebft"
            and data.get("ebft", {}).get("mode") == "strided"
            and data.get("activation_offloading") is True
            and data.get("flex_attention")
        ):
            raise ValueError(
                "EBFT strided mode: `activation_offloading: true` is incompatible with "
                "`flex_attention: true`. Activation offloading replaces gradient checkpointing "
                "with FSDP-style wrapping that conflicts with flex_attention's reentrant "
                "checkpoint requirement. Remove `activation_offloading` — the strided trainer "
                "uses micro-batched forward passes for memory efficiency instead."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_ebft_strided_sequence_len(cls, data):
        """Warn if sequence_len is too large for single-GPU strided EBFT."""
        if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided":
            return data
        ebft = data.get("ebft", {})
        seq_len = data.get("sequence_len", 512)
        n_samples = ebft.get("n_samples_per_prompt", 4)
        gen_len = ebft.get("generate_max_len", 8)
        stride = ebft.get("stride", 8)
        ctx_len = ebft.get("context_length", 8)
        max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
        full_seq = seq_len + max_blocks * gen_len
        # Rough estimate: 8.7 GB per sample at S=3900 for 1B model
        if full_seq * n_samples > 20000:
            LOG.warning(
                f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = "
                f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM "
                f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride."
            )
        return data
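The sequence-budget arithmetic in `check_ebft_strided_sequence_len` can be worked through with the `EBFTConfig` defaults. A quick sketch of the same formulas:

```python
# Strided-mode sequence budget with the EBFTConfig default values.
seq_len, stride, ctx_len, gen_len, n_samples = 512, 8, 8, 8, 4

# Number of anchor positions that fit within the sequence.
max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
# Total tokens once gen_len generated tokens are appended per block.
full_seq = seq_len + max_blocks * gen_len

print(max_blocks, full_seq, full_seq * n_samples)  # 63 1016 4064
```

At the defaults the per-step budget is 4064 token-samples, well under the validator's 20000 warning threshold; the warning fires once `sequence_len` grows toward the multi-thousand range.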

    @model_validator(mode="before")
    @classmethod
    def check_ebft_strided_dataset_split(cls, data):
        """Warn about the common `train_on_split` mistake (silently ignored by schema)."""
        datasets = data.get("datasets", [])
        for ds in datasets or []:
            if isinstance(ds, dict) and ds.get("train_on_split"):
                LOG.warning(
                    f"Dataset has `train_on_split: {ds['train_on_split']}` — this field "
                    f"is not recognized and will be silently ignored. "
                    f"Use `split: {ds['train_on_split']}` instead."
                )
        return data


class GRPOVllmValidationMixin:
    """Validation mixin for vllm when using GRPO."""

@@ -1507,6 +1625,7 @@ class ValidationMixin(
    PretrainingValidationMixin,
    ModelCompatibilityValidationMixin,
    ComplexValidationMixin,
    EBFTValidationMixin,
    GRPOVllmValidationMixin,
):
    """Full validation mixin for Axolotl configuration."""
@@ -57,6 +57,13 @@ class VllmConfig(BaseModel):
        default=None,
        json_schema_extra={"description": "Reasoning parser for VLLM"},
    )
    enforce_eager: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Disable CUDA graph capture in vLLM. Required for models with "
            "causal_conv1d (e.g., Qwen3.5 hybrid linear attention)."
        },
    )
    serve_module: str | None = Field(
        default=None,
        json_schema_extra={
94
src/axolotl/utils/weight_serde.py
Normal file
@@ -0,0 +1,94 @@
"""Serialize / deserialize tensors for HTTP and IPC weight sync.

NumPy doesn't support bfloat16, so bf16 tensors are cast to fp16 on the wire
and reconstructed at the destination. All encode/decode helpers live here so
the logic isn't duplicated across trl_vllm.py, vllm_serve_lora.py, and
vllm_worker_ext.py.
"""

import base64

import torch


def encode_for_http(name: str, weight: torch.Tensor) -> dict:
    """Encode a named parameter for JSON transport over HTTP.

    Returns a dict with keys: name, dtype (original), shape, data (base64).
    bf16 tensors are sent as fp16 bytes; the original dtype is preserved in
    the ``dtype`` field so the receiver can cast back.
    """
    w_cpu = weight.contiguous().cpu()
    orig_dtype = str(weight.dtype)
    if w_cpu.dtype == torch.bfloat16:
        w_cpu = w_cpu.half()
    raw = w_cpu.numpy().tobytes()
    return {
        "name": name,
        "dtype": orig_dtype,
        "shape": list(weight.shape),
        "data": base64.b64encode(raw).decode("ascii"),
    }


def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]:
    """Decode an HTTP-encoded weight entry back to a named tensor.

    Infers wire dtype from byte count (bf16 arrives as fp16) and casts to the
    original dtype stored in ``entry["dtype"]``.
    """
    target_dtype = getattr(torch, entry["dtype"].split(".")[-1])
    shape = tuple(entry["shape"])
    raw = base64.b64decode(entry["data"])

    n_elements = 1
    for s in shape:
        n_elements *= s
    wire_bytes_per_elem = len(raw) // max(n_elements, 1)
    if wire_bytes_per_elem == 2:
        wire_dtype = torch.float16
    elif wire_bytes_per_elem == 4:
        wire_dtype = torch.float32
    else:
        wire_dtype = target_dtype

    weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape)
    if wire_dtype != target_dtype:
        weight = weight.to(target_dtype)
    return entry["name"], weight


def encode_for_ipc(name: str, weight: torch.Tensor) -> dict:
    """Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64).

    Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype
    (original), shape. bf16 tensors are serialized as fp16.
    """
    w = weight.contiguous()
    target_dtype = str(w.dtype).split(".")[-1]
    if w.dtype == torch.bfloat16:
        w = w.half()
    wire_dtype = str(w.dtype).split(".")[-1]
    return {
        "name": name,
        "data": w.numpy().tobytes(),
        "dtype": wire_dtype,
        "target_dtype": target_dtype,
        "shape": list(weight.shape),
    }


def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]:
    """Decode an IPC-encoded weight entry back to a named tensor.

    Handles optional ``target_dtype`` for backward compatibility with older
    serve code that may not include it.
    """
    wire_dtype = getattr(torch, entry["dtype"])
    weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape(
        entry["shape"]
    )
    target_dtype = entry.get("target_dtype")
    if target_dtype and target_dtype != entry["dtype"]:
        weight = weight.to(getattr(torch, target_dtype))
    return entry["name"], weight
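The wire format above can be illustrated without torch: pack floats, base64-encode them, and on decode recover the element width from the byte count, mirroring the heuristic in `decode_from_http`. A stdlib-only sketch with hypothetical helper names (`struct` format `"e"` is float16, `"f"` is float32):

```python
import base64
import struct

def encode(name, values, fmt="f"):
    """Pack a flat list of floats and base64 them, like encode_for_http."""
    raw = struct.pack(f"{len(values)}{fmt}", *values)
    return {
        "name": name,
        "shape": [len(values)],
        "data": base64.b64encode(raw).decode("ascii"),
    }

def decode(entry):
    """Recover the values, inferring the wire dtype from bytes per element."""
    raw = base64.b64decode(entry["data"])
    n = entry["shape"][0]
    bytes_per_elem = len(raw) // max(n, 1)
    fmt = {2: "e", 4: "f"}[bytes_per_elem]  # 2 bytes -> fp16, 4 bytes -> fp32
    return entry["name"], list(struct.unpack(f"{n}{fmt}", raw))

name, vals = decode(encode("w", [1.0, 2.5, -3.0]))
print(name, vals)  # w [1.0, 2.5, -3.0]
```

The same byte-count inference is why a bf16 tensor can travel as fp16 bytes: the receiver sees 2-byte elements, decodes as fp16, then casts to the original dtype recorded alongside the payload.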
294
tests/test_ebft_kernels.py
Normal file
@@ -0,0 +1,294 @@
"""Correctness tests for fused EBFT Triton kernels."""

import pytest
import torch
import torch.nn.functional as F

from axolotl.core.trainers.ebft.kernels import (
    fused_cosine_similarity,
    fused_diversity_penalty,
    fused_log_softmax_gather,
    fused_reinforce_loss,
)

# Skip all tests if CUDA not available
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)

DEVICE = "cuda"


# ---------------------------------------------------------------------------
# 1. fused_log_softmax_gather
# ---------------------------------------------------------------------------
class TestFusedLogSoftmaxGather:
    def _reference(self, logits, labels):
        """PyTorch reference: log_softmax + gather."""
        lp = F.log_softmax(logits.float(), dim=-1)
        return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    def test_basic_correctness(self):
        B, S, V = 2, 16, 1024
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)

        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)

        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)

    def test_large_vocab(self):
        """Test with realistic vocab size (128K)."""
        B, S, V = 1, 8, 128256
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)

        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)

        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_fp32_input(self):
        B, S, V = 2, 8, 512
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
        labels = torch.randint(0, V, (B, S), device=DEVICE)

        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)

        torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)

    def test_output_is_negative(self):
        """log_softmax values should always be <= 0."""
        B, S, V = 4, 32, 2048
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)

        out = fused_log_softmax_gather(logits, labels)
        assert (out <= 1e-5).all(), "log_softmax values should be <= 0"

    def test_extreme_logits(self):
        """Test numerical stability with very large/small logits."""
        B, S, V = 1, 4, 256
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
        logits[:, 0, :] = 1000.0  # very large
        logits[:, 1, :] = -1000.0  # very small
        logits[:, 2, 0] = 1000.0  # one hot-ish
        labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)

        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)

        assert torch.isfinite(out).all(), "Should handle extreme values"
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        """Test with pre-flattened (N, V) input."""
        N, V = 64, 4096
        logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (N,), device=DEVICE)

        ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
        out = fused_log_softmax_gather(logits, labels)

        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)


# ---------------------------------------------------------------------------
# 2. fused_reinforce_loss
# ---------------------------------------------------------------------------
class TestFusedReinforceLoss:
    def _reference(self, logps, advantages, mask):
        """PyTorch reference implementation."""
        loss_per_token = -logps * advantages
        return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)

    def test_basic_correctness(self):
        N = 1024
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)

        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)

        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        """Test with (B, S) shaped inputs."""
        B, S = 4, 256
        logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)

        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)

        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_all_masked(self):
        """All-zero mask should return 0."""
        N = 512
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)

        out = fused_reinforce_loss(logps, advs, mask)
        assert out.item() == 0.0

    def test_all_unmasked(self):
        N = 512
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.ones(N, device=DEVICE, dtype=torch.bool)

        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)

        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_large(self):
        """Test with realistic size (4 * 3000 tokens)."""
        N = 12000
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)

        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)

        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)


# ---------------------------------------------------------------------------
# 3. fused_cosine_similarity
# ---------------------------------------------------------------------------
class TestFusedCosineSimilarity:
    def test_basic_correctness(self):
        N, D = 64, 256
        a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)

        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)

        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)

    def test_batched(self):
        """Test with (B, N, NB, D) shaped input."""
        B, N, NB, D = 2, 4, 16, 512
        a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)

        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)

        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_identical_vectors(self):
        """Identical vectors should give similarity = 1."""
        N, D = 16, 128
        a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)

        out = fused_cosine_similarity(a, a)
        torch.testing.assert_close(
            out,
            torch.ones(N, device=DEVICE, dtype=torch.float32),
            atol=1e-5,
            rtol=1e-5,
        )

    def test_orthogonal_vectors(self):
        """Orthogonal vectors should give similarity = 0."""
        D = 128
        a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
        b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
        a[0, 0] = 1.0
        b[0, 1] = 1.0

        out = fused_cosine_similarity(a, b)
        assert abs(out.item()) < 1e-5

    def test_opposite_vectors(self):
        """Opposite vectors should give similarity = -1."""
        N, D = 8, 64
        a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
        out = fused_cosine_similarity(a, -a)
        torch.testing.assert_close(
            out,
            -torch.ones(N, device=DEVICE, dtype=torch.float32),
            atol=1e-5,
            rtol=1e-5,
        )

    def test_large_dimension(self):
        """Test with large feature dimension (multi-layer concatenated features)."""
        N, D = 32, 4608  # 3 layers * 1536 hidden
        a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)

        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)

        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)


# ---------------------------------------------------------------------------
# 4. fused_diversity_penalty
# ---------------------------------------------------------------------------
class TestFusedDiversityPenalty:
    def _reference(self, embeddings):
        """PyTorch reference: bmm + mask diagonal + mean."""
        B, N, D = embeddings.shape
        sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
        eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
        sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
        return sims.sum(dim=-1) / (N - 1)

    def test_basic_correctness(self):
        B, N, D = 4, 4, 256
        emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)

        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)

        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_two_samples(self):
        """With n=2, diversity = dot(a, b) for each."""
        B, D = 3, 128
        emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)

        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)

        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_identical_samples(self):
        """All identical samples should give max diversity."""
        B, N, D = 2, 4, 64
        vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
        emb = vec.expand(B, N, D).contiguous()

        out = fused_diversity_penalty(emb)
        # Should be ||vec||^2 for each (self-excluded mean of identical dot products)
        expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
        torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)

    def test_large(self):
        """Test with realistic EBFT dimensions."""
        B, N, D = 1, 4, 4608  # multi-layer features
        emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)

        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)

        torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)

    def test_single_sample_returns_zeros(self):
        """N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
        B, D = 3, 128
        emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
        out = fused_diversity_penalty(emb)
        assert (out == 0).all(), "N=1 diversity should be exactly zero"
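The `_reference` for `fused_log_softmax_gather` and the `test_extreme_logits` case both hinge on the max-subtraction trick: subtracting the row max before exponentiating keeps the log-softmax finite even for logits of magnitude 1000. A dependency-free sketch of that computation for a single row (illustrative only, not the kernel):

```python
import math

def log_softmax_gather(logits_row, label):
    """Numerically stable log-softmax of one row, gathered at `label`."""
    m = max(logits_row)  # subtract the max so exp() never overflows
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return logits_row[label] - log_z

# Extreme values stay finite, matching what test_extreme_logits checks.
print(log_softmax_gather([1000.0, 0.0, -1000.0], 0))
```

Without the `m` subtraction, `exp(1000.0)` overflows to infinity and the result is NaN; with it, the dominant logit yields a log-probability near zero as expected.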
363
tests/test_ebft_strided_structured.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for the EBFT strided structured dataset transform and data loading."""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from axolotl.prompt_strategies.ebft import load as load_ebft
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer():
|
||||
"""Create a simple word-level tokenizer — no network access needed."""
|
||||
# Build a tiny vocab covering common test words
|
||||
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
|
||||
words = (
|
||||
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
|
||||
"x write code print test some string metadata noise ok good python "
|
||||
"sampling abc 123 def solve return this that"
|
||||
).split()
|
||||
for w in words:
|
||||
if w not in vocab:
|
||||
vocab[w] = len(vocab)
|
||||
|
||||
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
|
||||
backend.pre_tokenizer = pre_tokenizers.Whitespace()
|
||||
|
||||
tok = PreTrainedTokenizerFast(
|
||||
tokenizer_object=backend,
|
||||
bos_token="[BOS]",
|
||||
eos_token="[EOS]",
|
||||
pad_token="[PAD]",
|
||||
unk_token="[UNK]",
|
||||
)
|
||||
return tok
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cfg():
|
||||
return DictDefault({"sequence_len": 64})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transform_fn_and_kwargs(cfg):
|
||||
result = load_ebft("ebft_strided_structured.transform", cfg)
|
||||
assert result is not None, "Failed to load ebft_strided_structured transform"
|
||||
transform_fn, map_kwargs = result
|
||||
return transform_fn, map_kwargs
|
||||
|
||||
|
||||
class TestEBFTStridedStructuredTransform:
|
||||
"""Tests for the dataset transform function itself."""
|
||||
|
||||
def test_transform_loads(self, transform_fn_and_kwargs):
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
assert callable(transform_fn)
|
||||
assert "remove_columns" in map_kwargs
|
||||
|
||||
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
|
||||
"""Transform should signal removal of all original columns."""
|
||||
_, map_kwargs = transform_fn_and_kwargs
|
||||
assert map_kwargs["remove_columns"] == "__all__"
|
||||
|
||||
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Prompt tokens get labels=-100, completion tokens get real labels."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
assert "input_ids" in result
|
||||
assert "labels" in result
|
||||
assert "attention_mask" in result
|
||||
assert "prompt_length" in result
|
||||
|
||||
prompt_length = result["prompt_length"]
|
||||
labels = result["labels"]
|
||||
seq_len = len(result["input_ids"])
|
||||
|
||||
assert seq_len == 64, "Should be padded to sequence_len"
|
||||
assert len(labels) == seq_len
|
||||
assert prompt_length > 0
|
||||
|
||||
# Prompt tokens should be masked
|
||||
for i in range(prompt_length):
|
||||
assert labels[i] == -100, f"Prompt token at {i} should be -100"
|
||||
|
||||
# At least one completion token should have a real label
|
||||
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
|
||||
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
|
||||
|
||||
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""prompt_length should be the exact boundary between -100 and real labels."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "hello world", "output": "goodbye world"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
prompt_length = result["prompt_length"]
|
||||
labels = result["labels"]
|
||||
|
||||
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
|
||||
assert labels[prompt_length] != -100, (
|
||||
"First completion token should not be masked"
|
||||
)
|
||||
|
||||
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Padding tokens should have labels=-100 and attention_mask=0."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "hi", "output": "bye"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
attention_mask = result["attention_mask"]
|
||||
labels = result["labels"]
|
||||
|
||||
real_len = sum(attention_mask)
|
||||
assert real_len < 64, "Short example should have padding"
|
||||
|
||||
for i in range(real_len, 64):
|
||||
assert attention_mask[i] == 0, (
|
||||
f"Pad position {i} should have attention_mask=0"
|
||||
)
|
||||
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
|
||||
|
||||
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Long completions should be truncated to fit sequence_len."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {
|
||||
"input": "short prompt",
|
||||
"output": "x " * 500,
|
||||
}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
assert len(result["input_ids"]) == 64
|
||||
|
||||
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Transform should handle different field name conventions."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
|
||||
result = transform_fn(
|
||||
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
|
||||
)
|
||||
assert result["prompt_length"] > 0
|
||||
|
||||
result = transform_fn(
|
||||
{"question": "what", "answer": "this"}, tokenizer=tokenizer
|
||||
)
|
||||
assert result["prompt_length"] > 0
|
||||
|
||||
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
|
||||
"""Without tokenizer, should return a prompt string."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
result = transform_fn({"input": "hello", "output": "world"})
|
||||
assert "prompt" in result
|
||||
assert result["prompt"] == "hello"
|
||||
|
||||
|
||||
class TestEBFTColumnRemoval:
    """Tests for the __all__ column removal logic in the RL data path."""

    def _filter_remove_columns(self, map_kwargs, dataset):
        """Reproduce the filtering logic from rl.py _load_split."""
        if "remove_columns" in map_kwargs:
            ds_columns = dataset.column_names
            if map_kwargs["remove_columns"] == "__all__":
                map_kwargs["remove_columns"] = list(ds_columns)
            else:
                map_kwargs["remove_columns"] = [
                    c for c in map_kwargs["remove_columns"] if c in ds_columns
                ]
        return map_kwargs

    def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
        """After mapping, only tokenized columns should remain."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)  # copy

        ds = Dataset.from_list(
            [
                {"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
            ]
        )

        map_kwargs = self._filter_remove_columns(map_kwargs, ds)
        assert "input" in map_kwargs["remove_columns"]
        assert "output" in map_kwargs["remove_columns"]
        assert "extra_field" in map_kwargs["remove_columns"]

        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
        assert "input_ids" in mapped.column_names
        assert "labels" in mapped.column_names
        assert "prompt_length" in mapped.column_names
        assert "input" not in mapped.column_names
        assert "output" not in mapped.column_names
        assert "extra_field" not in mapped.column_names

    def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
        """Datasets with many extra metadata columns should all be cleaned up."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)

        ds = Dataset.from_list(
            [
                {
                    "input": "write hello world",
                    "output": "print hello",
                    "id": "abc 123",
                    "domain": "python",
                    "generation_algorithm": "sampling",
                    "llm_judgement": "good",
                    "unit_tests": "test",
                    "tests_execution_status": "ok",
                    "average_test_score": 0.95,
                },
            ]
        )

        map_kwargs = self._filter_remove_columns(map_kwargs, ds)
        assert len(map_kwargs["remove_columns"]) == 9

        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)

        expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
        assert set(mapped.column_names) == expected_columns

    def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
        """No string-typed columns should remain (would crash the DataLoader)."""
        transform_fn, map_kwargs = transform_fn_and_kwargs
        map_kwargs = dict(map_kwargs)

        ds = Dataset.from_list(
            [
                {"input": "test", "output": "test", "notes": "some string metadata"},
            ]
        )

        map_kwargs = self._filter_remove_columns(map_kwargs, ds)

        from functools import partial

        mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)

        for col in mapped.column_names:
            feat = mapped.features[col]
            assert str(feat) != "string", (
                f"Column '{col}' is still a string — would crash DataLoader"
            )

    def test_filter_preserves_explicit_list(self):
        """When remove_columns is an explicit list, only existing columns are kept."""
        ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
        map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}

        ds_columns = ds.column_names
        map_kwargs["remove_columns"] = [
            c for c in map_kwargs["remove_columns"] if c in ds_columns
        ]

        assert map_kwargs["remove_columns"] == ["a", "b"]
        assert "missing_col" not in map_kwargs["remove_columns"]

class TestMultiTurnSeparators:
    """Verify multi-turn transforms and trainer-side GT reconstruction."""

    def test_multiturn_transform_splits_turns(self):
        """Transform should store first turn as GT and remaining turns separately."""
        from axolotl.prompt_strategies.ebft import load as load_ebft
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"sequence_len": 512})
        fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
        out = fn(
            {
                "messages": [
                    {"role": "user", "content": "Q1"},
                    {"role": "assistant", "content": "A1"},
                    {"role": "user", "content": "Q2"},
                    {"role": "assistant", "content": "A2"},
                ]
            }
        )
        # ground_truth is only the first assistant turn
        assert out["ground_truth"] == "A1"
        # remaining_turns carries the rest for trainer-side reconstruction
        assert out["remaining_turns"] == [
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]

    def test_multiturn_gt_reconstruction_via_chat_template(self):
        """Trainer-side GT reconstruction should insert role markers between turns.

        This tests the logic from trainer.py:284-299 that reconstructs multi-turn
        GT using apply_chat_template, ensuring assistant turns are separated by
        role markers rather than concatenated as raw text.
        """
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )

        # Simulate the transform output
        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        remaining_turns = [
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]

        # --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )
        gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
        gt_conv.extend(remaining_turns)
        full_gt_text = tokenizer.apply_chat_template(
            gt_conv, tokenize=False, add_generation_prompt=False
        )

        # The full GT text should contain both assistant turns with role markers
        assert "A1" in full_gt_text
        assert "A2" in full_gt_text
        # Raw concatenation "A1A2" should NOT appear — role markers separate them
        assert "A1A2" not in full_gt_text, (
            "GT reconstruction should have role markers between turns, not raw concatenation"
        )
        # The user turn Q2 should appear between A1 and A2
        a1_pos = full_gt_text.index("A1")
        a2_pos = full_gt_text.index("A2")
        q2_pos = full_gt_text.index("Q2")
        assert a1_pos < q2_pos < a2_pos, (
            "Turn order should be A1 -> Q2 -> A2 in rendered GT"
        )
        # The GT should start with the prompt
        assert full_gt_text.startswith(prompt_text), (
            "Full GT should start with the rendered prompt"
        )

    def test_multiturn_gt_reconstruction_fallback_single_turn(self):
        """Single-turn prompts in a multi-turn dataset should use raw concatenation."""
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )

        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        # remaining_turns would be [] for single-turn prompts

        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )

        # With empty remaining_turns, trainer falls through to raw concat
        # (trainer.py:302: gt_texts.append(prompt_text + gt))
        gt_text = prompt_text + gt
        assert gt_text.endswith("A1")
        assert prompt_text in gt_text
158 tests/test_http_weight_sync.py Normal file
@@ -0,0 +1,158 @@
"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
|
||||
|
||||
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
|
||||
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.utils.weight_serde import (
|
||||
decode_from_http,
|
||||
decode_from_ipc,
|
||||
encode_for_http,
|
||||
encode_for_ipc,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 1: trainer → serve endpoint (HTTP with base64)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHttpEncodeRoundTrip:
    """Test encode_for_http / decode_from_http."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_http("layer.weight", original)
        name, decoded = decode_from_http(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            # bf16→fp16→bf16 loses some precision
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_wire_format_is_fp16(self):
        """bf16 tensors should be sent as fp16 on the wire."""
        import base64

        original = torch.randn(8, 16, dtype=torch.bfloat16)
        entry = encode_for_http("w", original)
        raw = base64.b64decode(entry["data"])
        # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
        assert len(raw) == 8 * 16 * 2
        # dtype field should preserve original dtype for reconstruction
        assert entry["dtype"] == "torch.bfloat16"

    def test_multidimensional_shapes(self):
        for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
            original = torch.randn(*shape, dtype=torch.bfloat16)
            entry = encode_for_http("w", original)
            _, decoded = decode_from_http(entry)
            assert decoded.shape == original.shape
            assert decoded.dtype == torch.bfloat16

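The helper bodies themselves are not part of this diff hunk. For orientation only, here is a minimal sketch of what `encode_for_http` / `decode_from_http` could look like, consistent with the assertions above (base64 payload, bf16 shipped as fp16 on the wire, original dtype preserved in the `dtype` field). The `name` and `shape` field names are assumptions, not confirmed by this hunk.

```python
import base64

import numpy as np
import torch


def encode_for_http(name, tensor):
    """Sketch: serialize a tensor for the HTTP leg (trainer -> serve)."""
    # bf16 has no numpy dtype, so ship it as fp16 bytes on the wire;
    # the original dtype is recorded so the receiver can restore it.
    wire = tensor.to(torch.float16) if tensor.dtype == torch.bfloat16 else tensor
    return {
        "name": name,  # assumed field name
        "data": base64.b64encode(wire.contiguous().numpy().tobytes()).decode("ascii"),
        "dtype": str(tensor.dtype),  # e.g. "torch.bfloat16"
        "shape": list(tensor.shape),  # assumed field name
    }


def decode_from_http(entry):
    """Sketch: restore (name, tensor) from an HTTP entry."""
    target = getattr(torch, entry["dtype"].replace("torch.", ""))
    # bf16 traveled as fp16 (see test_bfloat16_wire_format_is_fp16)
    np_dtype = {
        torch.float32: np.float32,
        torch.float16: np.float16,
        torch.bfloat16: np.float16,
    }[target]
    raw = base64.b64decode(entry["data"])
    arr = np.frombuffer(raw, dtype=np_dtype).copy()  # copy: frombuffer is read-only
    tensor = torch.from_numpy(arr).reshape(entry["shape"]).to(target)
    return entry["name"], tensor
```

This sketch round-trips fp32/fp16 exactly and bf16 with the small fp16 rounding error the tests tolerate.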
# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------


class TestIpcEncodeRoundTrip:
    """Test encode_for_ipc / decode_from_ipc."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_ipc("layer.weight", original)
        name, decoded = decode_from_ipc(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_ipc_wire_is_fp16(self):
        """bf16 tensors should be serialized as fp16 bytes in IPC."""
        original = torch.randn(4, 8, dtype=torch.bfloat16)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float16"
        assert entry["target_dtype"] == "bfloat16"
        assert len(entry["data"]) == 4 * 8 * 2  # fp16 bytes

    def test_fp32_has_no_target_dtype_mismatch(self):
        original = torch.randn(4, 8, dtype=torch.float32)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float32"
        assert entry["target_dtype"] == "float32"

    def test_worker_handles_missing_target_dtype(self):
        """Backward compat: older serve code may not send target_dtype."""
        entry = {
            "name": "w",
            "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
            "dtype": "float32",
            "shape": [4, 8],
            # no target_dtype key
        }
        name, decoded = decode_from_ipc(entry)
        assert decoded.dtype == torch.float32
        assert decoded.shape == (4, 8)

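The IPC leg's helpers are likewise absent from this hunk. A minimal sketch consistent with the fields the tests assert (`data` as raw bytes rather than base64, `dtype` as the wire dtype, `target_dtype` as the dtype to restore, with a fallback when older serve code omits `target_dtype`) could be:

```python
import numpy as np
import torch


def encode_for_ipc(name, tensor):
    """Sketch: serialize a tensor for the IPC leg (serve -> vLLM worker)."""
    # As on the HTTP leg, bf16 is converted to fp16 before .numpy(),
    # which would otherwise raise (bf16 has no numpy equivalent).
    wire = tensor.to(torch.float16) if tensor.dtype == torch.bfloat16 else tensor
    return {
        "name": name,
        "data": wire.contiguous().numpy().tobytes(),  # raw bytes, no base64
        "dtype": str(wire.dtype).replace("torch.", ""),  # wire dtype, e.g. "float16"
        "target_dtype": str(tensor.dtype).replace("torch.", ""),  # dtype to restore
        "shape": list(tensor.shape),
    }


def decode_from_ipc(entry):
    """Sketch: restore (name, tensor), tolerating a missing target_dtype."""
    np_dtype = {"float32": np.float32, "float16": np.float16}[entry["dtype"]]
    arr = np.frombuffer(entry["data"], dtype=np_dtype).copy()
    # Older serve code may not send target_dtype; fall back to the wire dtype.
    target = getattr(torch, entry.get("target_dtype", entry["dtype"]))
    tensor = torch.from_numpy(arr).reshape(entry["shape"]).to(target)
    return entry["name"], tensor
```

The `entry.get(...)` fallback is what makes the `test_worker_handles_missing_target_dtype` case decode as plain fp32.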
# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------


class TestFullPipelineRoundTrip:
    """End-to-end: trainer → serve → worker."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_three_stage_round_trip(self, dtype):
        """Tensor survives trainer→serve→worker with correct dtype and values."""
        original = torch.randn(16, 32, dtype=dtype)

        # Stage 1: trainer encodes for HTTP
        http_entry = encode_for_http("model.layers.0.weight", original)

        # Stage 2: serve decodes HTTP, re-encodes for IPC
        name, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc(name, at_serve)

        # Stage 3: worker decodes IPC
        _, at_worker = decode_from_ipc(ipc_entry)

        assert at_worker.dtype == dtype
        assert at_worker.shape == original.shape
        if dtype == torch.bfloat16:
            # Two bf16→fp16→bf16 hops compound precision loss slightly
            torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
        else:
            torch.testing.assert_close(at_worker, original, atol=0, rtol=0)

    def test_bfloat16_precision_loss_is_bounded(self):
        """bf16→fp16→bf16 round-trip error should be small."""
        original = torch.randn(256, 256, dtype=torch.bfloat16)
        http_entry = encode_for_http("w", original)
        _, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc("w", at_serve)
        _, at_worker = decode_from_ipc(ipc_entry)

        max_err = (at_worker.float() - original.float()).abs().max().item()
        # bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
        assert max_err < 0.05, f"Max error {max_err} exceeds bound"

    def test_bfloat16_numpy_would_crash_without_fix(self):
        """Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
        t = torch.randn(4, 4, dtype=torch.bfloat16)
        with pytest.raises((RuntimeError, TypeError)):
            t.numpy()