From c50c4acbf42b2816cea34a4dc18067bea4fd20b8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Mar 2026 18:43:46 -0400 Subject: [PATCH] EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci] * EBFT wip * fixes * more fixes * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address code review feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict --- docker/Dockerfile-cloud-uv | 1 + examples/ebft/README.md | 214 +++ examples/ebft/ebft_opencode.py | 28 + examples/ebft/ebft_pretrain.py | 31 + examples/ebft/ebft_strided_structured.py | 80 ++ .../ebft/llama-1b-ebft-opencode-novllm.yaml | 64 + examples/ebft/llama-1b-ebft-opencode.yaml | 81 ++ .../llama-1b-ebft-strided-structured.yaml | 65 + examples/ebft/llama-1b-ebft-strided.yaml | 60 + examples/ebft/llama-3b-ebft-strided-fft.yaml | 69 + examples/ebft/llama-8b-ebft-strided-fft.yaml | 58 + .../ebft/qwen35-4b-ebft-structured-async.yaml | 86 ++ examples/ebft/qwen35-4b-ebft-structured.yaml | 77 ++ examples/ebft/qwen35-9b-ebft-structured.yaml | 82 ++ src/axolotl/cli/vllm_serve.py | 23 +- src/axolotl/common/datasets.py | 2 +- src/axolotl/core/builders/rl.py | 14 +- src/axolotl/core/trainers/__init__.py | 2 + src/axolotl/core/trainers/ebft/__init__.py | 213 +++ src/axolotl/core/trainers/ebft/args.py | 133 ++ src/axolotl/core/trainers/ebft/kernels.py | 308 +++++ src/axolotl/core/trainers/ebft/rewards.py | 264 ++++ src/axolotl/core/trainers/ebft/strided.py | 1152 +++++++++++++++++ src/axolotl/core/trainers/ebft/trainer.py | 531 ++++++++ .../core/trainers/grpo/async_trainer.py | 203 ++- .../integrations/diffusion/callbacks.py | 6 +- src/axolotl/monkeypatch/trainer/trl_vllm.py | 121 +- .../prompt_strategies/ebft/__init__.py | 9 + .../ebft/ebft_chat_multiturn.py | 129 ++ .../prompt_strategies/ebft/ebft_opencode.py | 20 + .../prompt_strategies/ebft/ebft_reasoning.py | 319 +++++ .../ebft/ebft_strided_chat.py | 110 ++ .../ebft/ebft_strided_structured.py | 80 ++ src/axolotl/scripts/vllm_serve_lora.py | 77 +- src/axolotl/scripts/vllm_worker_ext.py | 52 +- src/axolotl/train.py | 8 +- src/axolotl/utils/callbacks/__init__.py | 29 +- src/axolotl/utils/callbacks/generation.py | 58 +- src/axolotl/utils/data/rl.py | 21 +- src/axolotl/utils/schemas/config.py | 121 +- src/axolotl/utils/schemas/enums.py | 1 + src/axolotl/utils/schemas/trl.py | 16 +- src/axolotl/utils/schemas/validation.py | 119 ++ src/axolotl/utils/schemas/vllm.py | 7 + src/axolotl/utils/weight_serde.py | 94 ++ tests/test_ebft_kernels.py | 294 +++++ tests/test_ebft_strided_structured.py | 363 ++++++ tests/test_http_weight_sync.py | 158 +++ 48 files changed, 5885 insertions(+), 168 deletions(-) create mode 100644 examples/ebft/README.md create mode 100644 examples/ebft/ebft_opencode.py create mode 100644 examples/ebft/ebft_pretrain.py create mode 100644 examples/ebft/ebft_strided_structured.py create mode 100644 examples/ebft/llama-1b-ebft-opencode-novllm.yaml create mode 100644 examples/ebft/llama-1b-ebft-opencode.yaml create mode 100644 
examples/ebft/llama-1b-ebft-strided-structured.yaml create mode 100644 examples/ebft/llama-1b-ebft-strided.yaml create mode 100644 examples/ebft/llama-3b-ebft-strided-fft.yaml create mode 100644 examples/ebft/llama-8b-ebft-strided-fft.yaml create mode 100644 examples/ebft/qwen35-4b-ebft-structured-async.yaml create mode 100644 examples/ebft/qwen35-4b-ebft-structured.yaml create mode 100644 examples/ebft/qwen35-9b-ebft-structured.yaml create mode 100644 src/axolotl/core/trainers/ebft/__init__.py create mode 100644 src/axolotl/core/trainers/ebft/args.py create mode 100644 src/axolotl/core/trainers/ebft/kernels.py create mode 100644 src/axolotl/core/trainers/ebft/rewards.py create mode 100644 src/axolotl/core/trainers/ebft/strided.py create mode 100644 src/axolotl/core/trainers/ebft/trainer.py create mode 100644 src/axolotl/prompt_strategies/ebft/__init__.py create mode 100644 src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py create mode 100644 src/axolotl/prompt_strategies/ebft/ebft_opencode.py create mode 100644 src/axolotl/prompt_strategies/ebft/ebft_reasoning.py create mode 100644 src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py create mode 100644 src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py create mode 100644 src/axolotl/utils/weight_serde.py create mode 100644 tests/test_ebft_kernels.py create mode 100644 tests/test_ebft_strided_structured.py create mode 100644 tests/test_http_weight_sync.py diff --git a/docker/Dockerfile-cloud-uv b/docker/Dockerfile-cloud-uv index a53dd6135..2facb6fa7 100644 --- a/docker/Dockerfile-cloud-uv +++ b/docker/Dockerfile-cloud-uv @@ -22,6 +22,7 @@ RUN apt update && \ chmod 700 ~/.ssh && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ + printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \ chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ chmod +x /root/cloud-entrypoint.sh && \ echo 'set-option -g history-limit 5000' >> ~/.tmux.conf diff --git a/examples/ebft/README.md b/examples/ebft/README.md new file mode 100644 index 000000000..533e13652 --- /dev/null +++ b/examples/ebft/README.md @@ -0,0 +1,214 @@ +# Energy-Based Fine-Tuning (EBFT) + +EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl. + +## Overview + +EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments. 
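+
+As a sketch, one update step looks like the following. This is illustrative pseudocode: `feature_net`, the `embed` pooling helper, and the fixed layer indices are placeholders and simplifications, not the trainer's actual API.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+@torch.no_grad()
+def embed(feature_net, input_ids, attention_mask, layer_idxs=(8, 16, 24)):
+    """Pool hidden states from the frozen feature network into one vector per sequence."""
+    out = feature_net(input_ids, attention_mask=attention_mask, output_hidden_states=True)
+    last = attention_mask.sum(-1) - 1  # index of the last real token in each row
+    batch = torch.arange(input_ids.size(0), device=input_ids.device)
+    feats = [
+        F.normalize(out.hidden_states[i][batch, last], dim=-1)
+        for i in layer_idxs  # e.g. 25% / 50% / 75% depth
+    ]
+    return torch.cat(feats, dim=-1)
+
+
+def feature_matching_reward(gen_emb, gt_emb):
+    """Per-sample reward: 2*alignment - 2*diversity (paper eq. 7, unit coefficients)."""
+    align = 2 * F.cosine_similarity(gen_emb, gt_emb.unsqueeze(0), dim=-1)  # (G,)
+    dots = gen_emb @ gen_emb.T  # (G, G) pairwise dot products
+    div = 2 * (dots.sum(-1) - dots.diagonal()) / (gen_emb.size(0) - 1)
+    return align - div
+
+
+def rloo_advantages(rewards):
+    """Leave-one-out baseline across the G rollouts of one prompt."""
+    return rewards - (rewards.sum() - rewards) / (rewards.numel() - 1)
+```
+
+The trainer then applies a REINFORCE-style update that weights each rollout's token log-probabilities by its advantage, which is what pushes the generated feature moments toward the ground truth.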
+ +**Key advantages over SFT:** +- Operates on model rollouts (not teacher forcing), reducing distribution shift +- Provides dense sequence-level supervision without a task-specific verifier +- Improves both downstream accuracy and validation cross-entropy simultaneously + +**Key advantages over RLVR:** +- No reward model or verifier required — works on any (prompt, completion) data +- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing) +- Maintains distributional calibration (low feature-matching loss) + +## Two Modes + +EBFT supports two modes depending on your data format: + +### Structured Mode (`mode: structured`, default) +For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation). +- Extends GRPOTrainer — uses vLLM for fast rollout generation +- RLOO advantages and clipped policy gradient from GRPO +- Feature-matching rewards replace external reward functions + +### Strided Mode (`mode: strided`) +For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode). +- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document +- No vLLM needed — generation uses custom strided attention masks +- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention) +- Compatible with gradient checkpointing via automatic dtype normalization +- This is the core EBFT algorithm from the paper (Section F) + +### Common to both modes: +- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode) +- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation +- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7) +- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation +- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss` +- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only) + +## Quick Start + +### Structured Mode (QA data + vLLM) + +```bash +# 1. Start vLLM server +python -m trl.scripts.vllm_serve \ + --model meta-llama/Llama-3.2-1B \ + --host 0.0.0.0 --port 8000 \ + --gpu-memory-utilization 0.3 + +# 2. Train +axolotl train examples/ebft/llama-1b-ebft-opencode.yaml +``` + +### Strided Mode (unstructured text) + +```bash +# No vLLM needed — strided generation is built-in +axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml +``` + +## Configuration + +### Common EBFT Settings + +```yaml +rl: ebft + +ebft: + # Feature network: which layers to extract hidden states from + # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer) + feature_layers: [0.25, 0.5, 0.75] + + # How to pool per-token hidden states into sequence embeddings + # Options: "last_token" (recommended), "mean_pooling", "concat" + embed_method: last_token + + # SVD whitening — strongly recommended (paper shows largest degradation without it) + use_whitening: true + + # Reward = alignment_coef * alignment - diversity_coef * diversity + # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized), + # diversity uses raw dot product — both are bounded after whitening. 
+ alignment_coef: 1.0 + diversity_coef: 1.0 + + # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1) + # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated + ce_coef: 0.0 +``` + +### Strided Mode Settings + +```yaml +ebft: + mode: strided + stride: 8 # tokens between anchor points (paper default: 8) + context_length: 8 # context window per block (paper default: 8) + generate_max_len: 8 # tokens generated per block (paper default: 8) + n_samples_per_prompt: 4 # independent rollouts per document (>= 2 for RLOO) + temperature: 0.6 + rl_coef: 1.0 # RL loss weight + advantage_estimator: rloo # rloo (recommended), group_norm, or reinforce +``` + +### Structured Mode Settings (via TRL) + +```yaml +trl: + num_generations: 4 # samples per prompt + max_completion_length: 256 # max tokens to generate + temperature: 1.0 + use_vllm: true + scale_rewards: true + loss_type: grpo + epsilon: 0.2 +``` + +### Dataset Format + +**Structured mode** — QA data with prompt + ground-truth completion: +```yaml +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform +``` +Transform returns: `{"prompt": ..., "ground_truth": ...}` + +**Strided mode** — raw text tokenized to fixed length: +```yaml +datasets: + - path: sjelassi/swallow_code_20m + type: ebft_pretrain.transform +``` +Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}` + +## How It Works + +### Structured Mode +1. **Generate**: For each prompt, generate `num_generations` completions via vLLM +2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network +3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7) +4. **RLOO advantages**: subtract leave-one-out group mean +5. **Policy gradient**: clipped PPO-style loss + +### Strided Mode +1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document +2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks +3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations +4. **Per-block rewards**: + - **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2] + - **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors + - **Reward** = `alignment_coef * alignment - diversity_coef * diversity` +5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block +6. 
**Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages + +### Tracked Metrics + +| Metric | Description | +|--------|-------------| +| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) | +| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) | +| `ebft/mean_reward` | alignment - diversity (should trend upward) | +| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better) | +| `ebft/rl_loss` | REINFORCE policy gradient loss | +| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) | +| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) | + +## Tips and Recommendations + +### Reward coefficients +- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`. +- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales. +- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO. 4 is the paper's default. +- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient. `0.0` gives pure feature matching. + +### Feature extraction +- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction. +- **`embed_method: last_token`**: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7). + +### Performance +- **`torch_compile: true`**: Recommended for strided mode. Provides additional speedup via graph compilation. +- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable. + +### Memory +- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory. +- **LoRA** is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters). +- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU. +- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights. 
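+
+### SVD whitening (sketch)
+For intuition about `use_whitening`: the idea is to decorrelate the stacked embedding dimensions before rewards are computed. Below is a minimal PCA-whitening sketch, illustrative only; the trainer's exact variant (centering, epsilon handling, rescaling) may differ.
+
+```python
+import torch
+
+
+def svd_whiten(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """x: (N, D) stacked embeddings -> (N, K) decorrelated, roughly unit-variance dims."""
+    xc = x - x.mean(dim=0, keepdim=True)  # center each dimension
+    _, s, vh = torch.linalg.svd(xc, full_matrices=False)
+    return (xc @ vh.T) / (s + eps) * (x.size(0) ** 0.5)  # project, rescale to unit variance
+```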
+ +## Example Configs + +| Config | Mode | Model | Description | +|--------|------|-------|-------------| +| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM | +| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM | +| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA | +| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation | + +## Citation + +```bibtex +@article{jelassi2026matching, + title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models}, + author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles}, + journal={arXiv preprint arXiv:2603.12248}, + year={2026} +} +``` diff --git a/examples/ebft/ebft_opencode.py b/examples/ebft/ebft_opencode.py new file mode 100644 index 000000000..677949ba5 --- /dev/null +++ b/examples/ebft/ebft_opencode.py @@ -0,0 +1,28 @@ +""" +Dataset transform for nvidia/OpenCodeInstruct with EBFT. + +Maps the dataset's `input` (prompt) and `output` (code solution) fields +to the format expected by the EBFT trainer. +""" + + +def transform(cfg, *args, **kwargs): + def transform_fn(example, tokenizer=None): + return { + "prompt": [ + {"role": "user", "content": example["input"]}, + ], + "ground_truth": example["output"], + } + + return transform_fn, { + "remove_columns": [ + "id", + "domain", + "generation_algorithm", + "llm_judgement", + "unit_tests", + "tests_execution_status", + "average_test_score", + ] + } diff --git a/examples/ebft/ebft_pretrain.py b/examples/ebft/ebft_pretrain.py new file mode 100644 index 000000000..27a1e54b9 --- /dev/null +++ b/examples/ebft/ebft_pretrain.py @@ -0,0 +1,31 @@ +""" +Dataset transform for unstructured text data with strided EBFT. + +Tokenizes raw text into fixed-length input_ids for the strided trainer. +Sequences are padded to sequence_len for uniform batching. +""" + + +def transform(cfg, *args, **kwargs): + seq_len = cfg.sequence_len + + def transform_fn(example, tokenizer=None): + text = example.get("question", example.get("text", "")) + if tokenizer is None: + return {"prompt": text} + + encoded = tokenizer( + text, + truncation=True, + max_length=seq_len, + padding="max_length", + add_special_tokens=True, + return_tensors=None, + ) + return { + "input_ids": encoded["input_ids"], + "attention_mask": encoded["attention_mask"], + "labels": list(encoded["input_ids"]), + } + + return transform_fn, {"remove_columns": ["question", "answer"]} diff --git a/examples/ebft/ebft_strided_structured.py b/examples/ebft/ebft_strided_structured.py new file mode 100644 index 000000000..48743931a --- /dev/null +++ b/examples/ebft/ebft_strided_structured.py @@ -0,0 +1,80 @@ +""" +Dataset transform for structured (prompt, completion) data with strided EBFT. + +Tokenizes prompt and completion separately, concatenates into a single +input_ids sequence, and marks prompt tokens with labels=-100 so the +strided trainer knows where to place anchors (completion span only). + +Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct). 
+""" + + +def transform(cfg, *args, **kwargs): + seq_len = cfg.sequence_len + + def transform_fn(example, tokenizer=None): + # Extract prompt and completion from the example + prompt_text = example.get( + "input", example.get("prompt", example.get("question", "")) + ) + completion_text = example.get( + "output", example.get("completion", example.get("answer", "")) + ) + + if tokenizer is None: + return {"prompt": prompt_text} + + pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id + + # Tokenize prompt and completion separately + prompt_enc = tokenizer( + prompt_text, + truncation=False, + add_special_tokens=True, + return_tensors=None, + ) + completion_enc = tokenizer( + completion_text, + truncation=False, + add_special_tokens=False, + return_tensors=None, + ) + + prompt_ids = prompt_enc["input_ids"] + completion_ids = completion_enc["input_ids"] + + # Truncate to fit within seq_len (prioritize keeping prompt + some completion) + total_len = len(prompt_ids) + len(completion_ids) + if total_len > seq_len: + # Truncate completion first, then prompt if needed + max_completion = seq_len - len(prompt_ids) + if max_completion < 1: + # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token + prompt_ids = prompt_ids[: seq_len - 1] + completion_ids = completion_ids[:1] + else: + completion_ids = completion_ids[:max_completion] + + input_ids = prompt_ids + completion_ids + prompt_length = len(prompt_ids) + + # Labels: -100 for prompt tokens, input_ids for completion tokens + labels = [-100] * prompt_length + completion_ids + + # Pad to seq_len + pad_len = seq_len - len(input_ids) + attention_mask = [1] * len(input_ids) + [0] * pad_len + labels = labels + [-100] * pad_len + input_ids = input_ids + [pad_id] * pad_len + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompt_length": prompt_length, + } + + # Signal to remove all original columns (filtered to existing ones at map time) + return transform_fn, { + "remove_columns": "__all__", + } diff --git a/examples/ebft/llama-1b-ebft-opencode-novllm.yaml b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml new file mode 100644 index 000000000..0891033f0 --- /dev/null +++ b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml @@ -0,0 +1,64 @@ +# EBFT validation config — no vLLM, uses HF generate for simplicity +# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml + +base_model: meta-llama/Llama-3.2-1B +chat_template: llama3 +rl: ebft + +ebft: + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: false + alignment_coef: 1.0 + diversity_coef: 1.0 + ce_coef: 0.0 + +trl: + num_generations: 4 + max_completion_length: 128 + temperature: 1.0 + use_vllm: false + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:1%] + +sequence_len: 512 +micro_batch_size: 2 +gradient_accumulation_steps: 2 +num_epochs: 1 +max_steps: 10 + +learning_rate: 1.0e-5 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 2 +weight_decay: 0.01 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_linear: true + +bf16: auto +flash_attention: true +gradient_checkpointing: true + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-validation + +wandb_project: ebft +wandb_run_id: +wandb_watch: +wandb_log_model: + +logging_steps: 1 +save_steps: 100 diff --git 
a/examples/ebft/llama-1b-ebft-opencode.yaml b/examples/ebft/llama-1b-ebft-opencode.yaml new file mode 100644 index 000000000..d0d1069d8 --- /dev/null +++ b/examples/ebft/llama-1b-ebft-opencode.yaml @@ -0,0 +1,81 @@ +# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct +# +# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026) +# https://arxiv.org/abs/2603.12248 +# +# Prerequisites: +# 1. Start vLLM server on a separate GPU: +# CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \ +# --model meta-llama/Llama-3.2-1B \ +# --host 0.0.0.0 --port 8000 \ +# --gpu-memory-utilization 0.4 --dtype bfloat16 +# +# 2. Run training: +# CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml + +base_model: meta-llama/Llama-3.2-1B +chat_template: llama3 + +# --- Training method --- +rl: ebft + +# --- EBFT configuration --- +ebft: + feature_layers: [0.25, 0.5, 0.75] # extract hidden states at 25%, 50%, 75% depth + embed_method: last_token # pool to sequence embedding via last token + use_whitening: false # SVD whitening (disable for speed in small runs) + alignment_coef: 1.0 # cosine similarity with ground-truth features + diversity_coef: 1.0 # pairwise similarity penalty + ce_coef: 0.0 # cross-entropy on ground-truth (0 = pure feature matching) + +# --- Generation settings (via TRL/GRPO infrastructure) --- +trl: + num_generations: 4 # samples per prompt for RLOO + max_completion_length: 256 # max generated tokens + temperature: 1.0 + use_vllm: true + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + +# --- Dataset --- +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:1%] # first 1% for validation runs + +# --- Training hyperparameters --- +sequence_len: 1024 +micro_batch_size: 2 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 50 + +learning_rate: 1.0e-5 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 5 +weight_decay: 0.01 + +# --- LoRA (recommended to reduce memory with frozen feature network) --- +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_linear: true + +# --- Hardware --- +bf16: auto +flash_attention: true +gradient_checkpointing: true + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-llama-1b-opencode + +# --- Logging --- +use_tensorboard: true +logging_steps: 1 +save_steps: 25 diff --git a/examples/ebft/llama-1b-ebft-strided-structured.yaml b/examples/ebft/llama-1b-ebft-strided-structured.yaml new file mode 100644 index 000000000..8ba63b64b --- /dev/null +++ b/examples/ebft/llama-1b-ebft-strided-structured.yaml @@ -0,0 +1,65 @@ +# EBFT Strided Structured Mode: For structured (prompt, completion) data +# Uses strided block-parallel generation on completion spans — no vLLM needed. 
+# +# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml + +base_model: meta-llama/Llama-3.2-1B +rl: ebft + +ebft: + mode: strided # strided block-parallel generation + stride: 8 # tokens between anchor points + context_length: 8 # context window per block + generate_max_len: 8 # tokens to generate per block + n_samples_per_prompt: 4 # rollouts per document + temperature: 0.6 + top_p: 1.0 + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: true + alignment_coef: 1.0 + diversity_coef: 1.0 + rl_coef: 1.0 + ce_coef: 0.03 # small CE weight for structured data + advantage_estimator: rloo + min_completion_prefix: 8 # skip anchors too close to prompt boundary + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_strided_structured.transform + split: train[:1%] + +sequence_len: 2048 +micro_batch_size: 1 +gradient_accumulation_steps: 2 +num_epochs: 1 +# max_steps: 10 + +learning_rate: 1.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 5 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_linear: true + +bf16: auto +flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime +flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true + # (full-model compile conflicts with gradient checkpointing + flex_attention) +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true # required for flex_attention (non-reentrant causes CheckpointError) + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-strided-structured + +wandb_project: ebft +logging_steps: 1 +save_steps: 100 diff --git a/examples/ebft/llama-1b-ebft-strided.yaml b/examples/ebft/llama-1b-ebft-strided.yaml new file mode 100644 index 000000000..c9519f160 --- /dev/null +++ b/examples/ebft/llama-1b-ebft-strided.yaml @@ -0,0 +1,60 @@ +# EBFT Strided Mode: For unstructured text data (raw code, prose) +# Uses strided block-parallel generation — no vLLM needed. 
+# +# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml + +base_model: meta-llama/Llama-3.2-1B +rl: ebft + +ebft: + mode: strided # strided block-parallel generation + stride: 8 # tokens between anchor points + context_length: 8 # context window per block + generate_max_len: 8 # tokens to generate per block + n_samples_per_prompt: 4 # rollouts per document + temperature: 0.6 + top_p: 1.0 + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: true + alignment_coef: 1.0 + diversity_coef: 1.0 + rl_coef: 1.0 + ce_coef: 0.0 + advantage_estimator: rloo + +datasets: + - path: sjelassi/swallow_code_20m + type: ebft_pretrain.transform + split: train[:100] + +sequence_len: 256 +micro_batch_size: 1 +gradient_accumulation_steps: 2 +num_epochs: 1 +max_steps: 5 + +learning_rate: 1.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 2 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_linear: true + +bf16: auto +flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime +gradient_checkpointing: true + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-strided-validation + +wandb_project: ebft +logging_steps: 1 +save_steps: 100 diff --git a/examples/ebft/llama-3b-ebft-strided-fft.yaml b/examples/ebft/llama-3b-ebft-strided-fft.yaml new file mode 100644 index 000000000..5695efa40 --- /dev/null +++ b/examples/ebft/llama-3b-ebft-strided-fft.yaml @@ -0,0 +1,69 @@ +# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps +# Actor on GPU 0, frozen feature network on GPU 1 +# +# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml + +base_model: meta-llama/Llama-3.2-3B +rl: ebft + +ebft: + mode: strided + stride: 8 + context_length: 8 + generate_max_len: 8 + n_samples_per_prompt: 4 + temperature: 0.6 + top_p: 1.0 + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: true + alignment_coef: 1.0 + diversity_coef: 1.0 + rl_coef: 1.0 + ce_coef: 0.0 # paper recommends 0.03 for mixed objective; 0.1 causes CE to dominate + advantage_estimator: rloo + +datasets: + - path: sjelassi/swallow_code_20m + type: ebft_pretrain.transform + split: train[:5000] + +sequence_len: 1024 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 100 + +learning_rate: 1.0e-5 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 10 +weight_decay: 0.01 + +adapter: lora +lora_r: 32 +lora_alpha: 64 +lora_dropout: 0.05 +lora_target_linear: true + +bf16: auto +torch_dtype: bfloat16 +flash_attention: false +gradient_checkpointing: true +torch_compile: true +gradient_checkpointing_kwargs: + use_reentrant: true +ddp: false +device_map: + "": 0 + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-llama3b-strided + +wandb_project: ebft +wandb_name: llama3b-strided-lora-100steps +logging_steps: 1 +save_steps: 50 diff --git a/examples/ebft/llama-8b-ebft-strided-fft.yaml b/examples/ebft/llama-8b-ebft-strided-fft.yaml new file mode 100644 index 000000000..8cf962849 --- /dev/null +++ b/examples/ebft/llama-8b-ebft-strided-fft.yaml @@ -0,0 +1,58 @@ +# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps +# Feature network is CPU-offloaded to fit in a single 32GB GPU +# +# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml + +base_model: 
meta-llama/Llama-3.1-8B +rl: ebft + +ebft: + mode: strided + stride: 8 + context_length: 8 + generate_max_len: 8 + n_samples_per_prompt: 4 + temperature: 0.6 + top_p: 1.0 + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: true + alignment_coef: 1.0 + diversity_coef: 1.0 + rl_coef: 1.0 + ce_coef: 0.0 + advantage_estimator: rloo + +datasets: + - path: sjelassi/swallow_code_20m + type: ebft_pretrain.transform + split: train[:5000] + +sequence_len: 1024 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 100 + +learning_rate: 1.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 10 +weight_decay: 0.01 + +bf16: auto +flash_attention: false # strided EBFT uses flex_attention at runtime +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false + +special_tokens: + pad_token: "<|end_of_text|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-llama8b-strided + +wandb_project: ebft +wandb_name: llama8b-strided-fft-100steps +logging_steps: 1 +save_steps: 50 diff --git a/examples/ebft/qwen35-4b-ebft-structured-async.yaml b/examples/ebft/qwen35-4b-ebft-structured-async.yaml new file mode 100644 index 000000000..759a31730 --- /dev/null +++ b/examples/ebft/qwen35-4b-ebft-structured-async.yaml @@ -0,0 +1,86 @@ +# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention) +# +# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers, +# full attention every 4th layer. This tests EBFT compatibility. +# +# Prerequisites: +# 1. Start vLLM on GPU 0: +# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml +# +# 2. Run training on GPU 1: +# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ +# axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml + +base_model: Qwen/Qwen3.5-4B + +rl: ebft + +ebft: + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: false + alignment_coef: 1.0 + diversity_coef: 1.0 + ce_coef: 0.0 + +trl: + num_generations: 4 + max_completion_length: 256 + temperature: 0.7 + use_vllm: true + vllm_server_host: 0.0.0.0 + vllm_server_port: 8000 + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + generation_kwargs: + stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|> + chat_template_kwargs: + enable_thinking: false + async_prefetch: true + vllm_server_timeout: 300 + +vllm: + gpu_memory_utilization: 0.5 + max_model_len: 2048 + serve_module: axolotl.scripts.vllm_serve_lora + enforce_eager: true + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:500] + +sequence_len: 1024 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 10 + +learning_rate: 5.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 3 +weight_decay: 0.01 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers +# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) 
which break vLLM LoRA +lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" + +bf16: auto +flash_attention: true +gradient_checkpointing: true + +special_tokens: + pad_token: "<|endoftext|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-qwen35-4b-structured-async + +wandb_project: ebft +logging_steps: 1 +save_steps: 50 diff --git a/examples/ebft/qwen35-4b-ebft-structured.yaml b/examples/ebft/qwen35-4b-ebft-structured.yaml new file mode 100644 index 000000000..9108e87e9 --- /dev/null +++ b/examples/ebft/qwen35-4b-ebft-structured.yaml @@ -0,0 +1,77 @@ +# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention) +# +# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers, +# full attention every 4th layer. This tests EBFT compatibility. +# +# Prerequisites: +# 1. Start vLLM on GPU 0: +# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \ +# --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager +# +# 2. Run training on GPU 1: +# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ +# axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml + +base_model: Qwen/Qwen3.5-4B + +rl: ebft + +ebft: + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: false + alignment_coef: 1.0 + diversity_coef: 1.0 + ce_coef: 0.0 + +trl: + num_generations: 4 + max_completion_length: 256 + temperature: 0.7 + use_vllm: true + vllm_server_host: 0.0.0.0 + vllm_server_port: 8000 + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + generation_kwargs: + stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|> + chat_template_kwargs: + enable_thinking: false # disable Qwen3.5 thinking mode for shorter completions + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:500] + +sequence_len: 1024 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 10 + +learning_rate: 5.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 3 +weight_decay: 0.01 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" + +bf16: auto +flash_attention: true +gradient_checkpointing: true + +special_tokens: + pad_token: "<|endoftext|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-qwen35-4b-structured + +wandb_project: ebft +logging_steps: 1 +save_steps: 50 diff --git a/examples/ebft/qwen35-9b-ebft-structured.yaml b/examples/ebft/qwen35-9b-ebft-structured.yaml new file mode 100644 index 000000000..e79fb5fbf --- /dev/null +++ b/examples/ebft/qwen35-9b-ebft-structured.yaml @@ -0,0 +1,82 @@ +# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention) +# +# Prerequisites: +# 1. Start vLLM on GPU 0: +# CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml +# +# 2. 
Run training on GPU 1: +# CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ +# axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml + +base_model: Qwen/Qwen3.5-9B + +rl: ebft + +ebft: + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: false + alignment_coef: 1.0 + diversity_coef: 1.0 + ce_coef: 0.0 + +trl: + num_generations: 4 + max_completion_length: 256 + temperature: 0.7 + use_vllm: true + vllm_server_host: 0.0.0.0 + vllm_server_port: 8000 + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + generation_kwargs: + stop_token_ids: [248044, 248046] # <|endoftext|>, <|im_end|> + chat_template_kwargs: + enable_thinking: false + vllm_server_timeout: 300 + +vllm: + gpu_memory_utilization: 0.7 + max_model_len: 2048 + serve_module: axolotl.scripts.vllm_serve_lora + enforce_eager: true + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:500] + +sequence_len: 1024 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +num_epochs: 1 +max_steps: 10 + +learning_rate: 3.0e-6 +optimizer: adamw_torch_fused +lr_scheduler: cosine +warmup_steps: 3 +weight_decay: 0.01 + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers +# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA +lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" + +bf16: auto +flash_attention: true +gradient_checkpointing: true + +special_tokens: + pad_token: "<|endoftext|>" + +val_set_size: 0.0 +output_dir: ./outputs/ebft-qwen35-9b-structured + +wandb_project: ebft +logging_steps: 1 +save_steps: 50 diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index 10db23878..2180a9e7f 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -38,18 +38,14 @@ def do_vllm_serve( cfg = load_cfg(config) model = cfg.base_model - # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default + # Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve). + # We default to axolotl's serve module instead of TRL's because TRL's sends + # truncate_prompt_tokens which is unsupported in vLLM 0.17+. 
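+    # For reference, both knobs can also be pinned in the YAML, as the qwen3.5
+    # example configs in this PR do:
+    #   vllm:
+    #     serve_module: axolotl.scripts.vllm_serve_lora
+    #     enforce_eager: true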
serve_module = cli_args.get("serve_module") or getattr( cfg.vllm, "serve_module", None ) - if ( - serve_module is None - and getattr(cfg, "trl", None) - and getattr(cfg.trl, "vllm_lora_sync", False) - ): - serve_module = "axolotl.scripts.vllm_serve_lora" if serve_module is None: - serve_module = "trl.scripts.vllm_serve" + serve_module = "axolotl.scripts.vllm_serve_lora" vllm_serve_main = __import__(serve_module, fromlist=["main"]).main tensor_parallel_size = 1 data_parallel_size = 1 @@ -79,6 +75,12 @@ def do_vllm_serve( cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False ) + cli_enforce_eager = cli_args.get("enforce_eager") + cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None) + raw_enforce_eager = ( + cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager + ) + enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False base_kwargs = dict( model=model, tensor_parallel_size=tensor_parallel_size, @@ -89,6 +91,7 @@ def do_vllm_serve( dtype=dtype, max_model_len=max_model_len, enable_prefix_caching=enable_prefix_caching, + enforce_eager=enforce_eager, ) # Use LoRAScriptArguments when serving with native LoRA support @@ -98,6 +101,10 @@ def do_vllm_serve( lora_kwargs = {} if hasattr(cfg, "lora_r") and cfg.lora_r: lora_kwargs["max_lora_rank"] = cfg.lora_r + # Disable native LoRA in vLLM if not using vllm_lora_sync + # (merged weight sync via batch_update doesn't need vLLM LoRA mode) + if not getattr(cfg.trl, "vllm_lora_sync", False): + lora_kwargs["enable_lora"] = False vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs) else: vllm_script_args = AxolotlScriptArguments( diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index c95ddb80e..b661e74c9 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -118,7 +118,7 @@ def load_preference_datasets( train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer) total_num_steps: int | None = None - if cfg.rl is not RLType.GRPO: + if cfg.rl not in {RLType.GRPO, RLType.EBFT}: total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index f7bf110cc..89d4c9ff7 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer_cls = AxolotlKTOTrainer elif self.cfg.rl is RLType.SIMPO: trainer_cls = AxolotlCPOTrainer + elif self.cfg.rl is RLType.EBFT: + from axolotl.core.trainers.ebft import EBFTStrategy + + trainer_cls = EBFTStrategy.get_trainer_class(self.cfg) + trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg)) else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") @@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl in [RLType.DPO, RLType.IPO]: training_args_cls = AxolotlDPOConfig training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg)) + + elif self.cfg.rl is RLType.EBFT: + from axolotl.core.trainers.ebft import EBFTStrategy + + training_args_cls = EBFTStrategy.get_training_args_class(self.cfg) + training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg)) + blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg) else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") @@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if ( self.cfg.adapter and self.peft_config - and self.cfg.rl not in 
(RLType.GRPO, RLType.ORPO) + and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT) ): trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 22d8b64f6..cdc09ae0a 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -4,6 +4,8 @@ from .base import AxolotlTrainer from .dpo.trainer import AxolotlDPOTrainer +from .ebft.strided import AxolotlStridedEBFTTrainer +from .ebft.trainer import AxolotlEBFTTrainer from .mamba import AxolotlMambaTrainer from .trl import ( AxolotlCPOTrainer, diff --git a/src/axolotl/core/trainers/ebft/__init__.py b/src/axolotl/core/trainers/ebft/__init__.py new file mode 100644 index 000000000..23b61fbe6 --- /dev/null +++ b/src/axolotl/core/trainers/ebft/__init__.py @@ -0,0 +1,213 @@ +"""EBFT (Energy-Based Fine-Tuning) Strategy for training + +Two modes: +- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM. +- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation. +""" + +from typing import Any + +from axolotl.core.trainers.ebft.args import ( + AxolotlAsyncEBFTConfig, + AxolotlEBFTConfig, + AxolotlStridedEBFTConfig, +) +from axolotl.utils.dict import DictDefault + + +def _get_ebft_mode(cfg: DictDefault) -> str: + """Determine EBFT mode from config.""" + if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode: + return cfg.ebft.mode + return "structured" + + +class EBFTStrategy: + """Strategy for EBFT training — dispatches between structured and strided modes.""" + + @classmethod + def get_trainer_class(cls, cfg: DictDefault | None = None): + mode = _get_ebft_mode(cfg) if cfg else "structured" + if mode == "strided": + from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer + + return AxolotlStridedEBFTTrainer + + # Structured mode: async or sync + use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False) + if use_async: + from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer + + return AxolotlAsyncEBFTTrainer + from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer + + return AxolotlEBFTTrainer + + @classmethod + def get_training_args_class(cls, cfg: DictDefault | None = None): + mode = _get_ebft_mode(cfg) if cfg else "structured" + if mode == "strided": + return AxolotlStridedEBFTConfig + + # Structured mode: async or sync config + use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False) + if use_async: + return AxolotlAsyncEBFTConfig + return AxolotlEBFTConfig + + @classmethod + def is_strided(cls, cfg: DictDefault) -> bool: + return _get_ebft_mode(cfg) == "strided" + + @classmethod + def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: + """Map axolotl YAML config fields to training args kwargs.""" + kwargs: dict[str, Any] = {} + mode = _get_ebft_mode(cfg) + + # Common EBFT fields + ebft = cfg.ebft + if ebft: + if ebft.feature_layers is not None: + kwargs["ebft_feature_layers"] = ebft.feature_layers + if ebft.embed_method is not None: + kwargs["ebft_embed_method"] = ebft.embed_method + if ebft.use_whitening is not None: + kwargs["ebft_use_whitening"] = ebft.use_whitening + if ebft.alignment_coef is not None: + kwargs["ebft_alignment_coef"] = ebft.alignment_coef + if ebft.diversity_coef is not None: + kwargs["ebft_diversity_coef"] = ebft.diversity_coef + if ebft.ce_coef is not None: + kwargs["ebft_ce_coef"] = 
ebft.ce_coef + if getattr(ebft, "adaptive_max_tokens", None) is not None: + kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens + if getattr(ebft, "gt_length_multiplier", None) is not None: + kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier + + if mode == "strided": + # Strided-specific fields + if ebft: + if ebft.stride is not None: + kwargs["ebft_stride"] = ebft.stride + if ebft.context_length is not None: + kwargs["ebft_context_length"] = ebft.context_length + if ebft.generate_max_len is not None: + kwargs["ebft_generate_max_len"] = ebft.generate_max_len + if ebft.n_samples_per_prompt is not None: + kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt + if ebft.temperature is not None: + kwargs["ebft_temperature"] = ebft.temperature + if ebft.top_p is not None: + kwargs["ebft_top_p"] = ebft.top_p + if ebft.rl_coef is not None: + kwargs["ebft_rl_coef"] = ebft.rl_coef + if ebft.advantage_estimator is not None: + kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator + if ebft.min_completion_prefix is not None: + kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix + else: + # Structured mode: map TRL config fields + trl = cfg.trl + if trl: + if trl.use_vllm: + kwargs["use_vllm"] = trl.use_vllm + if trl.vllm_mode: + kwargs["vllm_mode"] = trl.vllm_mode + if trl.vllm_mode == "colocate": + kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode + vllm_cfg = cfg.vllm + if vllm_cfg: + kwargs["vllm_gpu_memory_utilization"] = ( + vllm_cfg.gpu_memory_utilization + ) + kwargs["vllm_tensor_parallel_size"] = ( + vllm_cfg.tensor_parallel_size + ) + kwargs["vllm_server_host"] = trl.vllm_server_host or ( + trl.vllm.host if trl.vllm else None + ) + kwargs["vllm_server_port"] = trl.vllm_server_port or ( + trl.vllm.port if trl.vllm else None + ) + if trl.vllm_server_timeout: + kwargs["vllm_server_timeout"] = trl.vllm_server_timeout + + if trl.num_generations: + kwargs["num_generations"] = trl.num_generations + if trl.max_completion_length is not None: + kwargs["max_completion_length"] = trl.max_completion_length + if trl.temperature is not None: + kwargs["temperature"] = trl.temperature + if trl.top_p is not None: + kwargs["top_p"] = trl.top_p + if trl.top_k is not None: + kwargs["top_k"] = trl.top_k + if trl.min_p is not None: + kwargs["min_p"] = trl.min_p + if trl.num_iterations is not None: + kwargs["num_iterations"] = trl.num_iterations + if trl.epsilon is not None: + kwargs["epsilon"] = trl.epsilon + if trl.epsilon_high is not None: + kwargs["epsilon_high"] = trl.epsilon_high + if trl.scale_rewards is not None: + kwargs["scale_rewards"] = trl.scale_rewards + if trl.loss_type is not None: + kwargs["loss_type"] = trl.loss_type + if trl.mask_truncated_completions is not None: + kwargs["mask_truncated_completions"] = ( + trl.mask_truncated_completions + ) + if trl.log_completions is not None: + kwargs["log_completions"] = trl.log_completions + if trl.num_completions_to_print is not None: + kwargs["num_completions_to_print"] = trl.num_completions_to_print + if trl.sync_ref_model: + kwargs["sync_ref_model"] = trl.sync_ref_model + if trl.repetition_penalty is not None: + kwargs["repetition_penalty"] = trl.repetition_penalty + if trl.generation_kwargs is not None: + kwargs["generation_kwargs"] = trl.generation_kwargs + if trl.chat_template_kwargs is not None: + kwargs["chat_template_kwargs"] = trl.chat_template_kwargs + + # Async prefetch fields (only pass when enabled — sync config doesn't have these) + if getattr(trl, "async_prefetch", False): + 
kwargs["async_prefetch"] = trl.async_prefetch + if getattr(trl, "vllm_sync_interval", None) is not None: + kwargs["vllm_sync_interval"] = trl.vllm_sync_interval + if getattr(trl, "vllm_lora_sync", False): + kwargs["vllm_lora_sync"] = trl.vllm_lora_sync + + return kwargs + + @classmethod + def set_trainer_args(cls, cfg: DictDefault) -> list[Any]: + return [] + + @classmethod + def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: + return {} + + @classmethod + def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]: + mode = _get_ebft_mode(cfg) if cfg else "structured" + if mode == "strided": + return [ + "dataset_num_proc", + "max_length", + "max_prompt_length", + "include_tokens_per_second", + "beta", + ] + return [ + "dataset_num_proc", + "max_length", + "include_tokens_per_second", + "max_prompt_length", + ] + + @classmethod + def get_collator(cls, *args, **kwargs): + return None diff --git a/src/axolotl/core/trainers/ebft/args.py b/src/axolotl/core/trainers/ebft/args.py new file mode 100644 index 000000000..4a31b6b6b --- /dev/null +++ b/src/axolotl/core/trainers/ebft/args.py @@ -0,0 +1,133 @@ +""" +EBFT-specific training arguments. + +Three config classes: +- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation) +- AxolotlAsyncEBFTConfig: extends FastAsyncGRPOConfig for async structured QA data (async_prefetch) +- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation) +""" + +from dataclasses import dataclass, field + +from transformers import TrainingArguments +from trl import GRPOConfig + +from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig +from axolotl.core.training_args import AxolotlTrainingMixins + + +# -- Shared EBFT fields as a mixin -- +@dataclass +class EBFTFieldsMixin: + """Common fields shared between structured and strided EBFT configs.""" + + ebft_feature_layers: list[float] = field( + default_factory=lambda: [0.25, 0.5, 0.75], + metadata={"help": "Fractional layer depths for feature extraction"}, + ) + ebft_embed_method: str = field( + default="last_token", + metadata={"help": "Pooling method: 'last_token', 'mean_pooling', or 'concat'"}, + ) + ebft_use_whitening: bool = field( + default=False, + metadata={"help": "Apply SVD whitening to feature embeddings"}, + ) + ebft_alignment_coef: float = field( + default=1.0, + metadata={"help": "Coefficient for alignment reward (cosine similarity)"}, + ) + ebft_diversity_coef: float = field( + default=1.0, + metadata={"help": "Coefficient for diversity penalty"}, + ) + ebft_ce_coef: float = field( + default=0.0, + metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"}, + ) + ebft_adaptive_max_tokens: bool = field( + default=True, + metadata={"help": "Set per-batch max_tokens based on ground-truth length"}, + ) + ebft_gt_length_multiplier: float = field( + default=1.5, + metadata={ + "help": "Multiplier for ground-truth token count when computing adaptive max_tokens" + }, + ) + + +# -- Structured mode: extends GRPOConfig for QA data with vLLM -- +@dataclass +class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig): + """EBFT config for structured QA data — extends GRPOConfig.""" + + vllm_lora_sync: bool = field( + default=False, + metadata={ + "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge." 
+ }, + ) + + +# -- Async structured mode: extends FastAsyncGRPOConfig -- +@dataclass +class AxolotlAsyncEBFTConfig( + EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig +): + """EBFT config for async structured QA data — extends FastAsyncGRPOConfig. + + Includes all async fields: async_prefetch, vllm_lora_sync, + skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc. + """ + + vllm_lora_sync: bool = field( + default=False, + metadata={ + "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge." + }, + ) + + +# -- Strided mode: extends TrainingArguments for unstructured text -- +@dataclass +class AxolotlStridedEBFTConfig( + EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments +): + """EBFT config for unstructured text with strided block-parallel generation.""" + + ebft_stride: int = field( + default=8, + metadata={"help": "Stride between anchor points (in tokens)"}, + ) + ebft_context_length: int = field( + default=8, + metadata={"help": "Context window size for each block"}, + ) + ebft_generate_max_len: int = field( + default=8, + metadata={"help": "Number of tokens to generate per block"}, + ) + ebft_n_samples_per_prompt: int = field( + default=4, + metadata={"help": "Number of independent rollouts per document"}, + ) + ebft_temperature: float = field( + default=0.6, + metadata={"help": "Sampling temperature for strided generation"}, + ) + ebft_top_p: float = field( + default=1.0, + metadata={"help": "Top-p nucleus sampling threshold"}, + ) + ebft_rl_coef: float = field( + default=1.0, + metadata={"help": "RL policy gradient loss coefficient"}, + ) + ebft_advantage_estimator: str = field( + default="rloo", + metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"}, + ) + ebft_min_completion_prefix: int = field( + default=0, + metadata={"help": "Minimum tokens into completion before placing anchors"}, + ) diff --git a/src/axolotl/core/trainers/ebft/kernels.py b/src/axolotl/core/trainers/ebft/kernels.py new file mode 100644 index 000000000..d1d35feb4 --- /dev/null +++ b/src/axolotl/core/trainers/ebft/kernels.py @@ -0,0 +1,308 @@ +""" +Fused Triton kernels for strided EBFT. + +These kernels eliminate intermediate tensor materializations that dominate +the elementwise/fill category (~40% of CUDA time in profiling). + +Kernels: + 1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization) + 2. fused_masked_reinforce_loss: -logp * advantage * mask, reduced to scalar + 3. fused_cosine_similarity: batched cosine similarity without intermediate tensors +""" + +import torch +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# 1. Fused log_softmax + gather (selective log softmax) +# --------------------------------------------------------------------------- +# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels) +# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]] +# This avoids materializing the full (B, S, V) log_softmax output. 
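+# Reference (eager PyTorch) semantics, shown for clarity; not used at runtime:
+#   logp = torch.log_softmax(logits.float(), dim=-1)             # (B, S, V)
+#   out  = logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)     # (B, S)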
+ + +@triton.jit +def _fused_log_softmax_gather_kernel( + logits_ptr, # (B*S, V) row-major + labels_ptr, # (B*S,) int64 + output_ptr, # (B*S,) float32 + V: tl.constexpr, # vocab size + BLOCK_V: tl.constexpr, # tile width over vocab +): + """Compute log_softmax(logits)[label] for each row without materializing full output.""" + row = tl.program_id(0) + + logits_row_ptr = logits_ptr + row * V + label = tl.load(labels_ptr + row) + + # Pass 1: find max for numerical stability + max_val = -float("inf") + for v_start in range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf")) + max_val = tl.maximum(max_val, tl.max(vals, axis=0)) + + # Pass 2: compute sum(exp(x - max)) + sum_exp = 0.0 + for v_start in range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf")) + sum_exp += tl.sum(tl.exp(vals - max_val), axis=0) + + log_sum_exp = tl.log(sum_exp) + + # Gather: log_softmax[label] = logits[label] - max - log_sum_exp + target_logit = tl.load(logits_row_ptr + label) + result = target_logit - max_val - log_sum_exp + + tl.store(output_ptr + row, result) + + +def fused_log_softmax_gather( + logits: torch.Tensor, labels: torch.Tensor +) -> torch.Tensor: + """Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output. + + Args: + logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32) + labels: (B, S) or (B*S,) int64 tensor of target indices + + Returns: + (B, S) or (B*S,) float32 tensor of selected log probabilities + """ + orig_shape = logits.shape[:-1] + V = logits.shape[-1] + logits_2d = logits.reshape(-1, V).contiguous() + labels_1d = labels.reshape(-1).contiguous() + n_rows = logits_2d.shape[0] + + output = torch.empty(n_rows, device=logits.device, dtype=torch.float32) + + # Choose BLOCK_V: must be power of 2, large enough for good occupancy + BLOCK_V = min(triton.next_power_of_2(V), 65536) + + _fused_log_softmax_gather_kernel[(n_rows,)]( + logits_2d, + labels_1d, + output, + V=V, + BLOCK_V=BLOCK_V, + ) + + return output.view(orig_shape) + + +# --------------------------------------------------------------------------- +# 2. Fused masked REINFORCE loss reduction +# --------------------------------------------------------------------------- +# Instead of: (-logp * adv * mask).sum() / mask.sum() +# We do the masked multiply-accumulate in one kernel, returning (sum, count). 
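+# Reference (eager PyTorch) semantics, shown for clarity; not used at runtime:
+#   m = action_mask.float()
+#   loss = (-per_token_logps * advantages * m).sum() / m.sum().clamp(min=1)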
+ + +@triton.jit +def _fused_reinforce_loss_kernel( + logps_ptr, # (N,) float32 per-token log probs + advs_ptr, # (N,) float32 advantages + mask_ptr, # (N,) bool action mask + partial_sum_ptr, # (n_blocks,) partial sums + partial_cnt_ptr, # (n_blocks,) partial counts + N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + block_id = tl.program_id(0) + offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N) + valid = offsets < N + + logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0) + advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0) + m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32) + + # -logp * advantage * mask + loss = -logps * advs * m + block_sum = tl.sum(loss, axis=0) + block_cnt = tl.sum(m, axis=0) + + tl.store(partial_sum_ptr + block_id, block_sum) + tl.store(partial_cnt_ptr + block_id, block_cnt) + + +def fused_reinforce_loss( + per_token_logps: torch.Tensor, + advantages: torch.Tensor, + action_mask: torch.Tensor, +) -> torch.Tensor: + """Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum(). + + All inputs should be flat or will be flattened. Returns scalar loss tensor. + """ + logps_flat = per_token_logps.reshape(-1).contiguous() + advs_flat = advantages.reshape(-1).contiguous() + mask_flat = action_mask.reshape(-1).contiguous() + N = logps_flat.shape[0] + + BLOCK_N = 1024 + n_blocks = triton.cdiv(N, BLOCK_N) + + partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32) + partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32) + + _fused_reinforce_loss_kernel[(n_blocks,)]( + logps_flat, + advs_flat, + mask_flat, + partial_sum, + partial_cnt, + N=N, + BLOCK_N=BLOCK_N, + ) + + total_sum = partial_sum.sum() + total_cnt = partial_cnt.sum().clamp(min=1) + return total_sum / total_cnt + + +# --------------------------------------------------------------------------- +# 3. Fused cosine similarity (batched, for EBFT rewards) +# --------------------------------------------------------------------------- +# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots, +# we fuse the dot product, norm computation, and division into one kernel. + + +@triton.jit +def _fused_cosine_sim_kernel( + a_ptr, # (N, D) first set of vectors + b_ptr, # (N, D) second set of vectors + out_ptr, # (N,) cosine similarities + D: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + a_row_ptr = a_ptr + row * D + b_row_ptr = b_ptr + row * D + + dot = 0.0 + norm_a = 0.0 + norm_b = 0.0 + + for d_start in range(0, D, BLOCK_D): + d_offsets = d_start + tl.arange(0, BLOCK_D) + mask = d_offsets < D + a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32) + b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32) + + dot += tl.sum(a_vals * b_vals, axis=0) + norm_a += tl.sum(a_vals * a_vals, axis=0) + norm_b += tl.sum(b_vals * b_vals, axis=0) + + denom = tl.sqrt(norm_a) * tl.sqrt(norm_b) + denom = tl.where(denom > 1e-8, denom, 1e-8) + result = dot / denom + + tl.store(out_ptr + row, result) + + +def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Compute cosine similarity along the last dimension. 
+ + Args: + a, b: (..., D) tensors of the same shape + + Returns: + (...,) tensor of cosine similarities + """ + orig_shape = a.shape[:-1] + D = a.shape[-1] + a_2d = a.reshape(-1, D).contiguous() + b_2d = b.reshape(-1, D).contiguous() + N = a_2d.shape[0] + + output = torch.empty(N, device=a.device, dtype=torch.float32) + + BLOCK_D = min(triton.next_power_of_2(D), 4096) + + _fused_cosine_sim_kernel[(N,)]( + a_2d, + b_2d, + output, + D=D, + BLOCK_D=BLOCK_D, + ) + + return output.view(orig_shape) + + +# --------------------------------------------------------------------------- +# 4. Fused pairwise diversity penalty +# --------------------------------------------------------------------------- +# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1) +# We compute the pairwise dot products and exclusion in one kernel. + + +@triton.jit +def _fused_diversity_kernel( + emb_ptr, # (B, N, D) embeddings, row-major + out_ptr, # (B, N) diversity penalties + N: tl.constexpr, # n_samples + D: tl.constexpr, + BLOCK_D: tl.constexpr, +): + """For each (b, i), compute mean dot product to all j != i.""" + b = tl.program_id(0) + i = tl.program_id(1) + + # Pointer to emb[b, i, :] + emb_bi_ptr = emb_ptr + (b * N + i) * D + + total_sim = 0.0 + for j in range(N): + emb_bj_ptr = emb_ptr + (b * N + j) * D + + dot = 0.0 + for d_start in range(0, D, BLOCK_D): + d_offsets = d_start + tl.arange(0, BLOCK_D) + d_mask = d_offsets < D + a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to( + tl.float32 + ) + b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to( + tl.float32 + ) + dot += tl.sum(a_vals * b_vals, axis=0) + + # Exclude self-similarity (j == i) + is_other = j != i + total_sim += dot * is_other + + result = total_sim / (N - 1) + tl.store(out_ptr + b * N + i, result) + + +def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor: + """Compute mean pairwise dot product (excluding self) per sample. + + Args: + embeddings: (B, N, D) tensor where N is n_samples + + Returns: + (B, N) tensor of diversity penalties + """ + B, N, D = embeddings.shape + embeddings = embeddings.contiguous() + output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32) + if N <= 1: + return output # diversity is 0 when there's only one sample + + BLOCK_D = min(triton.next_power_of_2(D), 4096) + + _fused_diversity_kernel[(B, N)]( + embeddings, + output, + N=N, + D=D, + BLOCK_D=BLOCK_D, + ) + + return output diff --git a/src/axolotl/core/trainers/ebft/rewards.py b/src/axolotl/core/trainers/ebft/rewards.py new file mode 100644 index 000000000..993650bbe --- /dev/null +++ b/src/axolotl/core/trainers/ebft/rewards.py @@ -0,0 +1,264 @@ +""" +Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT). + +Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py +Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models" + (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@torch.no_grad() +def extract_hidden_states( + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + layer_indices: list[int], + batch_size: int | None = None, +) -> torch.Tensor: + """ + Forward pass through model, extracting and concatenating hidden states + at specified layer indices. 
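+
+    Example (illustrative; `frozen_model` is any frozen causal LM)::
+
+        feats = extract_hidden_states(
+            frozen_model, input_ids, attention_mask,
+            layer_indices=[8, 16, 24], batch_size=4,
+        )  # -> (B, S, 3 * H)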
+ + Args: + model: The frozen feature network + input_ids: (B, S) token ids + attention_mask: (B, S) attention mask + layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model) + batch_size: If set, process in chunks to reduce peak memory + + Returns: + Concatenated hidden states: (B, S, num_layers * H) + """ + if batch_size is None: + batch_size = input_ids.shape[0] + + # Use the inner transformer body (skips lm_head) when available. + # This avoids the expensive hidden_dim × vocab_size matmul whose + # output (logits) is never used — only hidden_states are needed. + body = getattr(model, "model", None) + if body is not None and hasattr(body, "forward"): + forward_model = body + else: + forward_model = model + + all_features = [] + for i in range(0, input_ids.shape[0], batch_size): + chunk_ids = input_ids[i : i + batch_size] + chunk_mask = attention_mask[i : i + batch_size] + + outputs = forward_model( + chunk_ids, + attention_mask=chunk_mask, + output_hidden_states=True, + return_dict=True, + ) + + # hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H) + # index 0 is the embedding layer output + hidden_states = outputs.hidden_states + chunk_features = [] + for idx in layer_indices: + chunk_features.append(hidden_states[idx]) + + # Concatenate across feature dimension: (B, S, num_layers * H) + all_features.append(torch.cat(chunk_features, dim=-1)) + + return torch.cat(all_features, dim=0) + + +def apply_embed_method( + hidden_states: torch.Tensor, + method: str, + attention_mask: torch.Tensor | None = None, + prompt_lengths: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Pool per-token hidden states into per-sequence embeddings. + + Args: + hidden_states: (B, S, D) concatenated hidden states + method: One of "last_token", "mean_pooling", "completion_mean", "concat" + attention_mask: (B, S) mask for mean pooling + prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean) + + Returns: + Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean, + (B, 3*D) for concat + """ + if method == "last_token": + if attention_mask is not None: + # Find last non-padding position per sample + last_idx = attention_mask.sum(dim=1).long() - 1 # (B,) + return hidden_states[torch.arange(hidden_states.shape[0]), last_idx] + return hidden_states[:, -1, :] + + if method == "mean_pooling": + if attention_mask is not None: + mask = attention_mask.unsqueeze(-1).float() # (B, S, 1) + return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) + return hidden_states.mean(dim=1) + + if method == "completion_mean": + # Mean pool over completion tokens only (exclude prompt) + if prompt_lengths is None: + raise ValueError("completion_mean requires prompt_lengths") + B, S, _ = hidden_states.shape + positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S) + comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S) + if attention_mask is not None: + comp_mask = comp_mask & attention_mask.bool() + mask = comp_mask.unsqueeze(-1).float() # (B, S, 1) + return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) + + if method == "concat": + B, S, D = hidden_states.shape + if attention_mask is not None: + valid_lens = attention_mask.sum(dim=1).long() # (B,) + else: + valid_lens = torch.full( + (B,), S, device=hidden_states.device, dtype=torch.long + ) + # Compute quartile positions relative to valid length per sample + # First valid position index for each sample (handles right-padding) + q1 = 
(valid_lens // 4).clamp(min=0, max=S - 1) + q2 = (valid_lens // 2).clamp(min=0, max=S - 1) + q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1) + batch_idx = torch.arange(B, device=hidden_states.device) + return torch.cat( + [ + hidden_states[batch_idx, q1], + hidden_states[batch_idx, q2], + hidden_states[batch_idx, q3], + ], + dim=-1, + ) + + raise ValueError(f"Unknown embed_method: {method}") + + +@torch.no_grad() +def get_alignment_rewards( + gen_embedding: torch.Tensor, + gt_embedding: torch.Tensor, +) -> torch.Tensor: + """ + Compute alignment reward as cosine similarity between generated + and ground-truth feature embeddings. + + Args: + gen_embedding: (B, D) generated sequence embeddings + gt_embedding: (B, D) ground-truth sequence embeddings + If num_generations > 1, gt_embedding should be repeated + to match gen_embedding's batch dim. + + Returns: + Alignment rewards: (B,) cosine similarities in [-1, 1] + """ + return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1) + + +@torch.no_grad() +def get_diversity_rewards( + gen_embedding: torch.Tensor, + num_generations: int, +) -> torch.Tensor: + """ + Compute diversity penalty as mean pairwise dot-product similarity + between samples from the same prompt (excluding self-similarity). + + Args: + gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations + num_generations: Number of generations per prompt + + Returns: + Diversity penalties: (B,) mean similarity to other samples from same prompt + """ + if num_generations <= 1: + return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device) + + num_prompts = gen_embedding.shape[0] // num_generations + + # Reshape to (num_prompts, num_generations, D) + reshaped = gen_embedding.view(num_prompts, num_generations, -1) + + # Pairwise dot products within each group: (num_prompts, num_generations, num_generations) + sims = torch.bmm(reshaped, reshaped.transpose(1, 2)) + + # Zero out self-similarity (diagonal) + eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool) + sims = sims.masked_fill(eye.unsqueeze(0), 0.0) + + # Mean similarity to other samples: (num_prompts, num_generations) + diversity = sims.sum(dim=-1) / (num_generations - 1) + + # Flatten back to (B,) + return diversity.view(-1) + + +def whiten_embeddings_batched( + phi: torch.Tensor, + phi_gt: torch.Tensor, + whiten_tol: float = 1e-5, + normalize: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Whiten generated embeddings using SVD, then apply same transform to ground-truth. + + Whitening decorrelates feature dimensions so no single direction dominates + the feature-matching loss. Uses pseudo-inverse for rank-deficient cases. + + Note: Singular values scale with sqrt(B), so reward magnitudes are + batch-size dependent. This is acceptable because B = n_samples_per_prompt + which is fixed during training (typically 2-4). 
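+
+    Example (illustrative)::
+
+        phi_w, phi_gt_w = whiten_embeddings_batched(gen_emb, gt_emb)
+        # phi_w.T @ phi_w is then roughly the identity on the span of the
+        # generated features (a projection when rank-deficient).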
+ + Args: + phi: (B, D) generated embeddings (used to estimate covariance) + phi_gt: (B, D) ground-truth embeddings + whiten_tol: Tolerance for singular value cutoff + normalize: If True, L2-normalize after whitening + + Returns: + Whitened (phi, phi_gt) tuple, each (B, D) + """ + phi_f = phi.float() + phi_gt_f = phi_gt.float() + + # Feature-space SVD: operate on phi_f.T (D, B) so U is (D, D) + try: + U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False) + except torch._C._LinAlgError: + # Fallback: add small noise + noise = 1e-6 * phi_f.abs().mean() + try: + U, S, _ = torch.linalg.svd( + (phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0), + full_matrices=False, + ) + except torch._C._LinAlgError: + if normalize: + return ( + F.normalize(phi, p=2, dim=-1), + F.normalize(phi_gt, p=2, dim=-1), + ) + return phi, phi_gt + + U, S = U.squeeze(0), S.squeeze(0) # U: (D, min(D,B)), S: (min(D,B),) + + # Safe inverse of singular values + s_max = S.max() + inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S)) + + # W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D) + W = (U * inv_s.unsqueeze(0)) @ U.T # (D, D) + phi_w = (phi_f @ W).to(phi.dtype) # (B, D) + phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype) # (B, D) + + if normalize: + phi_w = F.normalize(phi_w, p=2, dim=-1) + phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1) + + return phi_w, phi_gt_w diff --git a/src/axolotl/core/trainers/ebft/strided.py b/src/axolotl/core/trainers/ebft/strided.py new file mode 100644 index 000000000..5cfc5b99b --- /dev/null +++ b/src/axolotl/core/trainers/ebft/strided.py @@ -0,0 +1,1152 @@ +""" +Strided block-parallel EBFT trainer for unstructured text data. + +This trainer implements the full EBFT algorithm from the paper, including +strided block-parallel generation where multiple short rollouts are generated +at different anchor points within a single document. This is essential for +training on raw text data (code, prose, etc.) without prompt/completion splits. + +Uses torch flex_attention with a compiled block mask for efficient strided +attention patterns. Falls back to eager attention with dense 4D masks when +flex_attention is not available. + +Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models" + (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248 +""" + +import contextlib +import copy + +import torch +import torch.nn.functional as F +from transformers import Trainer + +from axolotl.core.trainers.ebft.kernels import ( # noqa: F401 — available for future use + fused_cosine_similarity, + fused_diversity_penalty, + fused_log_softmax_gather, + fused_reinforce_loss, +) +from axolotl.core.trainers.ebft.rewards import ( + whiten_embeddings_batched, +) +from axolotl.core.trainers.mixins import ( + DistributedParallelMixin, + RngLoaderMixin, + SchedulerMixin, +) +from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +# Check flex_attention availability +_FLEX_ATTENTION_AVAILABLE = False +try: + from torch.nn.attention.flex_attention import ( + create_block_mask, + ) + + _FLEX_ATTENTION_AVAILABLE = True +except ImportError: + pass + + +def _patch_flex_attention_dtype(): + """ + Patch HF's flex_attention_forward to cast q/k/v to a uniform dtype. + + This fixes the incompatibility between flex_attention and gradient + checkpointing, where recomputation can produce q/k in float32 while + v stays in bfloat16. 
flex_attention requires all three to match. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + original_fn = ALL_ATTENTION_FUNCTIONS.get("flex_attention") + if original_fn is None: + return + + def patched_flex_attention_forward( + module, query, key, value, attention_mask, **kwargs + ): + # Cast q/k/v to the same dtype (use value's dtype as reference, + # since model weights stay in the original dtype) + target_dtype = value.dtype + if query.dtype != target_dtype: + query = query.to(target_dtype) + if key.dtype != target_dtype: + key = key.to(target_dtype) + return original_fn(module, query, key, value, attention_mask, **kwargs) + + ALL_ATTENTION_FUNCTIONS["flex_attention"] = patched_flex_attention_forward + + +@contextlib.contextmanager +def override_attn_implementation(model, implementation: str): + """Temporarily override a model's attention implementation. + + Useful for forcing eager attention during generation (where sequence + lengths change each step, causing dynamo recompiles) while keeping + flex_attention for the fixed-size training forward pass. + + Usage:: + + with override_attn_implementation(model, "eager"): + output = model(input_ids, attention_mask=dense_4d_mask, ...) + """ + config = getattr(model, "config", None) + if config is None or not hasattr(config, "_attn_implementation"): + yield + return + + saved = config._attn_implementation + config._attn_implementation = implementation + try: + yield + finally: + config._attn_implementation = saved + + +# --------------------------------------------------------------------------- +# Strided attention mask builders +# --------------------------------------------------------------------------- + + +def _strided_mask_mod( + b, + h, + q_idx, + kv_idx, + prompt_length, + context_length, + max_generation_length, + stride, + num_blocks, +): + """ + Mask mod function for flex_attention's create_block_mask. + + Defines the strided block-parallel attention pattern: + - Prompt region: standard causal + - Generated region: each block attends to its context window + same-block predecessors + """ + # --- Prompt region: standard causal --- + is_prompt_q = q_idx < prompt_length + is_prompt_kv = kv_idx < prompt_length + prompt_causal = is_prompt_q & is_prompt_kv & (q_idx >= kv_idx) + + # --- Generated region --- + is_gen_q = ~is_prompt_q + # Which generation step and block does this query token belong to? + gen_offset = q_idx - prompt_length + gen_step = gen_offset // num_blocks + block_idx = gen_offset % num_blocks + + # Context window end for this block. + # Note: if prompt_length < max_generation_length, context_end clamps to 0 for all + # blocks. This is safe because compute_loss guards with num_blocks <= 0 → zero loss. 
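+    # (Illustrative: with prompt_length=512 and max_generation_length=8, the
+    # clamp caps context_end at 504, keeping every context window inside the prompt.)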
+ context_end = torch.clamp( + block_idx * stride + context_length, + max=prompt_length - max_generation_length, + ) + + # Rule 1: Generated token attends to its context window in the prompt + in_context = is_gen_q & is_prompt_kv & (kv_idx < context_end) + + # Rule 2: Self-attention + is_self = q_idx == kv_idx + + # Rule 3: Attend to earlier tokens in the SAME block (same block_idx, earlier gen_step) + is_gen_kv = ~is_prompt_kv + kv_gen_offset = kv_idx - prompt_length + kv_gen_step = kv_gen_offset // num_blocks + kv_block_idx = kv_gen_offset % num_blocks + same_block_prev = ( + is_gen_q & is_gen_kv & (kv_block_idx == block_idx) & (kv_gen_step < gen_step) + ) + + return prompt_causal | in_context | is_self | same_block_prev + + +def create_strided_block_mask( + prompt_length: int, + context_length: int, + max_generation_length: int, + stride: int, + num_blocks: int, + full_sequence_length: int, + batch_size: int, + num_heads: int, + device: torch.device, +): + """ + Create a BlockMask for flex_attention using the strided EBFT pattern. + + Returns a BlockMask that can be passed directly to model.forward() + when using attn_implementation="flex_attention". + + Parameters that vary across training steps (context_length, num_blocks) + are captured as tensors so torch.compile/dynamo treats them as dynamic + values rather than guarding on literal int values (which causes recompiles). + """ + # Wrap ALL mask params as 0-d tensors to prevent dynamo from guarding + # on their int values. Without this, each new anchor_offset or num_blocks + # triggers a recompile until the limit is hit → unfused fallback → OOM. + _prompt_length = torch.tensor(prompt_length, device=device) + _context_length = torch.tensor(context_length, device=device) + _max_gen_len = torch.tensor(max_generation_length, device=device) + _stride = torch.tensor(stride, device=device) + _num_blocks = torch.tensor(num_blocks, device=device) + + def mask_mod(b, h, q_idx, kv_idx): + return _strided_mask_mod( + b, + h, + q_idx, + kv_idx, + prompt_length=_prompt_length, + context_length=_context_length, + max_generation_length=_max_gen_len, + stride=_stride, + num_blocks=_num_blocks, + ) + + block_mask = create_block_mask( + mask_mod, + B=batch_size, + H=None, # broadcast across heads + Q_LEN=full_sequence_length, + KV_LEN=full_sequence_length, + device=device, + ) + return block_mask + + +def build_strided_position_ids( + full_sequence_length: int, + prompt_length: int, + context_length: int, + generation_step: int, + stride: int, + num_blocks: int, + device: torch.device, + batch_size: int = 1, +): + """Build position IDs for strided generation (shared between flex and eager modes).""" + position_ids = torch.empty( + (batch_size, full_sequence_length), dtype=torch.long, device=device + ) + position_ids[:, :prompt_length] = torch.arange(prompt_length, device=device) + + block_starting_positions = ( + torch.arange(num_blocks, device=device) * stride + context_length + ) + for gen_step in range(generation_step): + start = prompt_length + gen_step * num_blocks + end = start + num_blocks + position_ids[:, start:end] = block_starting_positions + gen_step + + return position_ids + + +def build_strided_dense_mask_and_positions( + full_sequence_length: int, + prompt_length: int, + context_length: int, + generation_step: int, + max_generation_length: int, + stride: int, + num_blocks: int, + device: torch.device, + batch_size: int = 1, + dtype: torch.dtype = torch.bfloat16, +): + """Build dense 4D attention mask (eager fallback) + position IDs.""" + 
min_value = torch.finfo(dtype).min + attention_mask = torch.full( + (batch_size, 1, full_sequence_length, full_sequence_length), + min_value, + dtype=dtype, + device=device, + ) + + causal_mask = torch.tril( + torch.ones((prompt_length, prompt_length), dtype=torch.bool, device=device) + ) + attention_mask[:, :, :prompt_length, :prompt_length].masked_fill_( + causal_mask.view(1, 1, prompt_length, prompt_length), 0.0 + ) + + for gen_step in range(generation_step): + for block_idx in range(num_blocks): + gen_pos = prompt_length + gen_step * num_blocks + block_idx + context_end = min( + block_idx * stride + context_length, + prompt_length - max_generation_length, + ) + attention_mask[:, 0, gen_pos, :context_end] = 0.0 + attention_mask[:, 0, gen_pos, gen_pos] = 0.0 + if gen_step > 0: + for prev_s in range(gen_step): + prev_pos = prompt_length + prev_s * num_blocks + block_idx + attention_mask[:, 0, gen_pos, prev_pos] = 0.0 + + position_ids = build_strided_position_ids( + full_sequence_length, + prompt_length, + context_length, + generation_step, + stride, + num_blocks, + device, + batch_size, + ) + return attention_mask, position_ids + + +# --------------------------------------------------------------------------- +# Trainer +# --------------------------------------------------------------------------- + + +class AxolotlStridedEBFTTrainer( + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + Trainer, +): + """ + Strided block-parallel EBFT trainer for unstructured text data. + + Takes full text documents (no prompt/completion split needed), generates + short rollouts at multiple anchor points via strided attention, and trains + with feature-matching rewards. + + When flex_attention is available (torch >= 2.5), uses compiled block masks + for efficient fused attention kernels. Otherwise falls back to eager + attention with dense 4D masks. + """ + + _tag_names = ["ebft", "strided", "axolotl"] + + def __init__(self, model, args, train_dataset, **kwargs): + super().__init__(model=model, args=args, train_dataset=train_dataset, **kwargs) + + # EBFT config + self.ebft_stride = getattr(args, "ebft_stride", 8) + self.ebft_context_length = getattr(args, "ebft_context_length", 8) + self.ebft_generate_max_len = getattr(args, "ebft_generate_max_len", 8) + self.ebft_n_samples = getattr(args, "ebft_n_samples_per_prompt", 4) + self.ebft_temperature = getattr(args, "ebft_temperature", 0.6) + self.ebft_top_p = getattr(args, "ebft_top_p", 1.0) + self.ebft_alignment_coef = getattr(args, "ebft_alignment_coef", 1.0) + self.ebft_diversity_coef = getattr(args, "ebft_diversity_coef", 1.0) + self.ebft_rl_coef = getattr(args, "ebft_rl_coef", 1.0) + self.ebft_ce_coef = getattr(args, "ebft_ce_coef", 0.0) + self.ebft_use_whitening = getattr(args, "ebft_use_whitening", False) + self.ebft_advantage_estimator = getattr( + args, "ebft_advantage_estimator", "rloo" + ) + self.ebft_min_completion_prefix = getattr(args, "ebft_min_completion_prefix", 0) + + # Validate config combinations + if self.ebft_use_whitening and self.ebft_diversity_coef > 0: + LOG.info( + "ebft: whitening + diversity enabled. Per paper Variant (i) (eq 49): " + "alignment uses cosine similarity (normalized), diversity uses raw dot product. " + "Both are bounded after whitening." + ) + if self.ebft_n_samples == 1 and self.ebft_diversity_coef > 0: + LOG.warning( + "ebft.n_samples_per_prompt=1 with diversity_coef > 0: diversity penalty requires " + "multiple samples. Setting diversity_coef to 0." 
+ ) + self.ebft_diversity_coef = 0.0 + if self.ebft_n_samples == 1 and self.ebft_advantage_estimator == "rloo": + LOG.warning( + "ebft.n_samples_per_prompt=1 with advantage_estimator='rloo': RLOO requires " + "multiple samples for baseline. Falling back to 'reinforce'." + ) + self.ebft_advantage_estimator = "reinforce" + + # Feature network config + feature_layers_frac = getattr(args, "ebft_feature_layers", [0.25, 0.5, 0.75]) + embed_method = getattr(args, "ebft_embed_method", "last_token") + self.ebft_embed_method = embed_method + + # Attention implementation selection + unwrapped = self.accelerator.unwrap_model(self.model) + self.use_flex_attention = ( + _FLEX_ATTENTION_AVAILABLE and torch.cuda.is_available() + ) + + if self.use_flex_attention: + _patch_flex_attention_dtype() + LOG.info("Using flex_attention for strided EBFT (compiled block masks)") + if hasattr(unwrapped.config, "_attn_implementation"): + unwrapped.config._attn_implementation = "flex_attention" + self._num_heads = unwrapped.config.num_attention_heads + else: + LOG.info("Using eager attention for strided EBFT (dense 4D masks)") + if hasattr(unwrapped.config, "_attn_implementation"): + unwrapped.config._attn_implementation = "eager" + self._num_heads = None + + # Feature network setup: either share weights with actor (PEFT models) + # or deepcopy (full-parameter models / multi-GPU). + first_param = next(unwrapped.parameters()) + original_device = first_param.device + actor_gpu = ( + original_device.index + if (original_device.type == "cuda" and original_device.index is not None) + else 0 + ) + visible_gpus = torch.cuda.device_count() + + # Check if we can share weights (PEFT model on single GPU) + from peft import PeftModel + + self._share_feature_weights = ( + isinstance(unwrapped, PeftModel) + and visible_gpus == 1 + and original_device.type != "meta" + ) + + if self._share_feature_weights: + # Share weights: use actor's base model with adapters disabled for + # feature extraction. Saves ~2.5 GB (no deepcopy of base weights). + self.feature_network = None # no separate network + self._feature_device = torch.device(f"cuda:{actor_gpu}") + self._feature_use_flex = self.use_flex_attention + LOG.info( + "Feature network shares actor weights (PEFT disable_adapter). 
" + f"Saving {sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9:.1f} GB" + ) + elif visible_gpus > 1 and original_device.type != "meta": + # Multi-GPU: deepcopy to a separate device + self.feature_network = copy.deepcopy(unwrapped) + self.feature_network.to(dtype=torch.bfloat16) + self._feature_device = torch.device( + f"cuda:{(actor_gpu + 1) % visible_gpus}" + ) + LOG.info(f"Creating frozen feature network on {self._feature_device}...") + self.feature_network.to(device=self._feature_device) + if _FLEX_ATTENTION_AVAILABLE and self._feature_device.type == "cuda": + if hasattr(self.feature_network.config, "_attn_implementation"): + self.feature_network.config._attn_implementation = "flex_attention" + self._feature_use_flex = True + LOG.info("Feature network using flex_attention") + else: + if hasattr(self.feature_network.config, "_attn_implementation"): + self.feature_network.config._attn_implementation = "eager" + self._feature_use_flex = False + for param in self.feature_network.parameters(): + param.requires_grad = False + self.feature_network.eval() + elif original_device.type == "meta": + # FSDP2 with cpu_ram_efficient_loading + from transformers import AutoModelForCausalLM + + feature_model_name = ( + getattr(args, "model_name_or_path", None) + or unwrapped.config._name_or_path + ) + self.feature_network = AutoModelForCausalLM.from_pretrained( + feature_model_name, + torch_dtype=torch.bfloat16, + attn_implementation="eager", + ) + self._feature_device = torch.device(f"cuda:{actor_gpu}") + self.feature_network.to(device=self._feature_device) + self._feature_use_flex = False + for param in self.feature_network.parameters(): + param.requires_grad = False + self.feature_network.eval() + LOG.warning("Feature network loaded from pretrained (meta device)") + else: + # Single-GPU, non-PEFT: deepcopy on same device + self.feature_network = copy.deepcopy(unwrapped) + self.feature_network.to(dtype=torch.bfloat16) + self._feature_device = torch.device(f"cuda:{actor_gpu}") + self.feature_network.to(device=self._feature_device) + if _FLEX_ATTENTION_AVAILABLE: + if hasattr(self.feature_network.config, "_attn_implementation"): + self.feature_network.config._attn_implementation = "flex_attention" + self._feature_use_flex = True + else: + if hasattr(self.feature_network.config, "_attn_implementation"): + self.feature_network.config._attn_implementation = "eager" + self._feature_use_flex = False + for param in self.feature_network.parameters(): + param.requires_grad = False + self.feature_network.eval() + LOG.info( + f"Created frozen feature network (deepcopy) on {self._feature_device}" + ) + + num_layers = unwrapped.config.num_hidden_layers + self.feature_layer_indices = [ + int(frac * num_layers) for frac in feature_layers_frac + ] + LOG.info( + f"Strided EBFT: layers={self.feature_layer_indices}, " + f"stride={self.ebft_stride}, ctx={self.ebft_context_length}, " + f"gen_len={self.ebft_generate_max_len}, n_samples={self.ebft_n_samples}, " + f"embed={embed_method}, flex_attn={self.use_flex_attention}, " + f"min_completion_prefix={self.ebft_min_completion_prefix}" + ) + + def _build_strided_mask( + self, + full_seq_len, + seq_len, + generation_step, + num_blocks, + batch_size, + device, + dtype, + anchor_offset=None, + ): + """Build strided attention mask + position IDs using flex or eager. + + Args: + anchor_offset: Position where anchors start. For unstructured data this + equals context_length; for structured data it equals + max(prompt_length + min_completion_prefix, context_length). 
+ Defaults to self.ebft_context_length if not provided. + """ + if anchor_offset is None: + anchor_offset = self.ebft_context_length + + pos_ids = build_strided_position_ids( + full_seq_len, + seq_len, + anchor_offset, + generation_step, + self.ebft_stride, + num_blocks, + device, + batch_size, + ) + + if self.use_flex_attention: + block_mask = create_strided_block_mask( + prompt_length=seq_len, + context_length=anchor_offset, + max_generation_length=self.ebft_generate_max_len, + stride=self.ebft_stride, + num_blocks=num_blocks, + full_sequence_length=full_seq_len, + batch_size=batch_size, + num_heads=self._num_heads, + device=device, + ) + return block_mask, pos_ids + + dense_mask, pos_ids = build_strided_dense_mask_and_positions( + full_sequence_length=full_seq_len, + prompt_length=seq_len, + context_length=anchor_offset, + generation_step=generation_step, + max_generation_length=self.ebft_generate_max_len, + stride=self.ebft_stride, + num_blocks=num_blocks, + device=device, + batch_size=batch_size, + dtype=dtype, + ) + return dense_mask, pos_ids + + def compute_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): + """ + Full strided EBFT training step. + + 1. Take tokenized documents from inputs + 2. Generate n_samples short rollouts at strided anchor points + 3. Extract features from frozen network for both generated and GT blocks + 4. Compute alignment/diversity rewards per block + 5. Compute RLOO advantages + 6. Policy gradient loss on the strided forward pass + + Supports both unstructured text (no prompt/completion split) and + structured data (prompt + completion with labels masking). For structured + data, anchors are placed only within the completion span. + """ + outputs = None + device = next(model.parameters()).device + input_ids = inputs["input_ids"].to(device) # (B, seq_len) + B, seq_len = input_ids.shape + + stride = self.ebft_stride + ctx_len = self.ebft_context_length + gen_len = self.ebft_generate_max_len + n_samples = self.ebft_n_samples + + # --- Detect structured data and compute anchor_offset --- + # For structured data, anchors must start within the completion span. + # anchor_offset replaces ctx_len as the starting position for anchors. + is_structured = False + if "prompt_length" in inputs: + # Explicit prompt_length from dataset transform + prompt_lengths = inputs["prompt_length"].to(device) # (B,) + is_structured = True + elif "labels" in inputs: + # Derive prompt_length from labels: first position where labels != -100 + labels = inputs["labels"].to(device) + non_masked = labels != -100 + # prompt_length = index of first non-masked token (or seq_len if all masked) + has_completion = non_masked.any(dim=1) + prompt_lengths = torch.where( + has_completion, + non_masked.float().argmax(dim=1), + torch.tensor(seq_len, device=device, dtype=torch.float), + ).long() + is_structured = prompt_lengths.min().item() > 0 + + if is_structured: + # Use max prompt_length across batch for uniform anchor_offset + max_prompt_len = prompt_lengths.max().item() + anchor_offset = max( + max_prompt_len + self.ebft_min_completion_prefix, ctx_len + ) + else: + anchor_offset = ctx_len + + num_blocks = (seq_len - gen_len - anchor_offset) // stride + 1 + if num_blocks <= 0: + LOG.warning( + f"Sequence too short for strided EBFT: seq_len={seq_len}, " + f"anchor_offset={anchor_offset}, " + f"need >= {gen_len + anchor_offset + stride}. Returning zero loss." 
+ ) + dummy_loss = input_ids.float().mean() * 0.0 + return (dummy_loss, None) if return_outputs else dummy_loss + + # --- Step 1: Generate strided blocks for n_samples --- + repeated_ids = input_ids.repeat_interleave(n_samples, dim=0) + + with torch.no_grad(): + full_sequences = self._generate_strided_blocks( + model, + repeated_ids, + num_blocks, + anchor_offset=anchor_offset, + ) + + # --- Step 2: Build strided mask for full generation --- + full_seq_len = full_sequences.shape[1] + model_dtype = next(model.parameters()).dtype + + # Free generation-phase memory before training forward pass + torch.cuda.empty_cache() + + attn_mask, pos_ids = self._build_strided_mask( + full_seq_len, + seq_len, + gen_len, + num_blocks, + B * n_samples, + device, + model_dtype, + anchor_offset=anchor_offset, + ) + + # --- Step 3: Forward pass through actor for log probs --- + # Memory optimization: process one sample at a time through the backbone + # to avoid B*N × S × H activation memory. For Llama-1B at S=3900, each + # sample's backbone forward takes ~8.7 GB with grad checkpointing. + # Processing B*N=4 at once would need ~35 GB → OOM. + # Instead, we accumulate per-token logprobs sample-by-sample. + gen_start = seq_len - 1 # shifted index where generated tokens start + compute_start = 0 if self.ebft_ce_coef > 0 else gen_start + BN = B * n_samples + + unwrapped_model = self.accelerator.unwrap_model(model) + # Navigate through PEFT wrapper to get backbone + lm_head + base_model = getattr(unwrapped_model, "model", unwrapped_model) + if hasattr(base_model, "model") and hasattr(base_model, "lm_head"): + backbone = base_model.model + lm_head = base_model.lm_head + else: + backbone = None + + per_token_logps_list = [] + shift_labels = full_sequences[:, 1:] + + if backbone is not None: + # Process one sample at a time: backbone → chunked lm_head → logprobs + # This keeps peak memory at ~1 sample's activations instead of B*N. 
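+            # Illustrative shapes at this point: full_sequences is
+            # (B * n_samples, seq_len + gen_len * num_blocks); logprobs are only
+            # needed from compute_start onward, i.e. from gen_start = seq_len - 1
+            # unless a CE term over the prompt region is also requested.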
+ for s_idx in range(BN): + seq_s = full_sequences[s_idx : s_idx + 1] # (1, full_seq_len) + # Handle attention mask format (BlockMask vs dense 4D) + if isinstance(attn_mask, torch.Tensor) and attn_mask.dim() == 4: + mask_s = attn_mask[s_idx : s_idx + 1] + else: + mask_s = attn_mask # BlockMask broadcasts over batch + pos_s = pos_ids[s_idx : s_idx + 1] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + backbone_out = backbone( + seq_s, + attention_mask=mask_s, + position_ids=pos_s, + return_dict=True, + ) + hidden_s = backbone_out.last_hidden_state # (1, full_seq_len, H) + labels_s = shift_labels[s_idx : s_idx + 1] + + logps_s = torch.zeros( + 1, + hidden_s.shape[1] - 1, + device=device, + dtype=torch.float32, + ) + + region_h = hidden_s[:, compute_start:-1, :] + region_l = labels_s[:, compute_start:] + chunk_size = 256 + for i in range(0, region_h.shape[1], chunk_size): + h_chunk = region_h[:, i : i + chunk_size, :] + l_chunk = region_l[:, i : i + chunk_size] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits_chunk = lm_head(h_chunk) + chunk_lp = F.log_softmax(logits_chunk.float(), dim=-1) + logps_s[ + :, compute_start + i : compute_start + i + h_chunk.shape[1] + ] = chunk_lp.gather(-1, l_chunk.unsqueeze(-1)).squeeze(-1) + del logits_chunk, chunk_lp + per_token_logps_list.append(logps_s) + del hidden_s, backbone_out, region_h + + per_token_logps = torch.cat(per_token_logps_list, dim=0) + else: + # Fallback: full forward (non-standard model architecture) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model( + full_sequences, + attention_mask=attn_mask, + position_ids=pos_ids, + return_dict=True, + ) + logits = outputs.logits + per_token_logps = torch.zeros( + logits.shape[0], + logits.shape[1] - 1, + device=device, + dtype=torch.float32, + ) + region_logits = logits[:, compute_start:-1, :] + region_labels = shift_labels[:, compute_start:] + chunk_size = 256 + for i in range(0, region_logits.shape[1], chunk_size): + chunk_logits = region_logits[:, i : i + chunk_size, :] + chunk_labels = region_labels[:, i : i + chunk_size] + chunk_lp = F.log_softmax(chunk_logits.float(), dim=-1) + per_token_logps[ + :, compute_start + i : compute_start + i + chunk_logits.shape[1] + ] = chunk_lp.gather(-1, chunk_labels.unsqueeze(-1)).squeeze(-1) + del logits, region_logits + + action_mask = torch.zeros( + per_token_logps.shape, dtype=torch.bool, device=device + ) + # Only mark actual generated tokens (not padding beyond num_blocks * gen_len) + gen_end = gen_start + num_blocks * gen_len + action_mask[:, gen_start:gen_end] = True + + # --- Step 4: Extract features and compute rewards --- + with torch.no_grad(): + block_rewards = self._compute_block_rewards( + full_sequences, + attn_mask, + pos_ids, + input_ids, + num_blocks, + B, + n_samples, + anchor_offset=anchor_offset, + ) + + del attn_mask, pos_ids + torch.cuda.empty_cache() + + # --- Step 5: Compute advantages --- + advantages_per_block = self._compute_advantages( + block_rewards, B, n_samples, num_blocks + ) + + token_advantages = advantages_per_block.repeat_interleave(gen_len, dim=1) + full_advantages = torch.zeros_like(per_token_logps) + # Only fill actual generated region (not padding beyond num_blocks * gen_len) + adv_len = token_advantages.shape[1] # = num_blocks * gen_len + full_advantages[:, gen_start : gen_start + adv_len] = token_advantages + + # --- Step 6: Compute loss --- + # RL loss: REINFORCE on generated tokens (needs grad through per_token_logps) + rl_loss_per_token = 
-per_token_logps * full_advantages.detach() + rl_loss = ( + rl_loss_per_token * action_mask.float() + ).sum() / action_mask.float().sum().clamp(min=1) + + # CE loss: For structured data, only compute on completion ground-truth tokens + # (labels != -100 in the original input). For unstructured data, compute on + # all non-action (prompt) tokens as before. + ce_loss = torch.tensor(0.0, device=device) + if self.ebft_ce_coef > 0: + if is_structured and "labels" in inputs: + labels = inputs["labels"].to(device) # (B, seq_len) + shifted_labels = labels[:, 1:] # (B, seq_len - 1) + ce_mask_base = shifted_labels != -100 # (B, seq_len - 1) + ce_mask_repeated = ce_mask_base.repeat_interleave(n_samples, dim=0) + ce_mask = torch.zeros( + per_token_logps.shape, dtype=torch.bool, device=device + ) + ce_mask[:, : ce_mask_repeated.shape[1]] = ce_mask_repeated + ce_mask[:, gen_start:] = False + else: + ce_mask = ~action_mask + ce_loss = ( + -per_token_logps * ce_mask.float() + ).sum() / ce_mask.float().sum().clamp(min=1) + + loss = self.ebft_rl_coef * rl_loss + self.ebft_ce_coef * ce_loss + + # --- Log metrics --- + if self.state.global_step % self.args.logging_steps == 0: + _alignment = getattr(self, "_last_alignment", 0.0) + _diversity = getattr(self, "_last_diversity", 0.0) + _cfm = getattr(self, "_last_cfm", 0.0) + _mean_reward = block_rewards.mean().item() + _adv_std = advantages_per_block.std().item() + + log_dict = { + "ebft/rl_loss": rl_loss.item(), + "ebft/ce_loss": ce_loss.item(), + "ebft/cfm_loss": _cfm, + "ebft/mean_reward": _mean_reward, + "ebft/alignment": _alignment, + "ebft/diversity": _diversity, + "ebft/num_blocks": num_blocks, + "ebft/advantages_std": _adv_std, + } + if is_structured: + log_dict["ebft/anchor_offset"] = anchor_offset + self.log(log_dict) + + # Human-readable summary with direction arrows: + # alignment (^ better) — cosine sim to GT features, range [-2, 2] + # diversity (v better) — pairwise sim penalty, lower = more diverse + # cfm_loss (v better) — ||E[phi(y_hat)] - phi(y)||^2 + # reward (^ better) — alignment - diversity + LOG.info( + f"step {self.state.global_step} | " + f"align {_alignment:+.3f} ^ | " + f"divers {_diversity:+.3f} v | " + f"cfm {_cfm:.3f} v | " + f"reward {_mean_reward:+.3f} ^ | " + f"adv_std {_adv_std:.3f} | " + f"blocks {num_blocks}" + ) + + return (loss, outputs) if return_outputs else loss + + @torch._dynamo.disable + @torch.no_grad() + def _generate_strided_blocks( + self, model, prompt_ids, num_blocks, anchor_offset=None + ): + """Generate tokens using strided block-parallel attention. + + Uses eager attention (dense 4D masks) during generation to avoid dynamo + recompilation — each generation step has a different sequence length. + The training forward pass (fixed size) uses flex_attention when available. + + Args: + anchor_offset: Position where anchors start. Defaults to context_length. + """ + B, seq_len = prompt_ids.shape + gen_len = self.ebft_generate_max_len + stride = self.ebft_stride + if anchor_offset is None: + anchor_offset = self.ebft_context_length + temperature = self.ebft_temperature + top_p = self.ebft_top_p + device = prompt_ids.device + model_dtype = next(model.parameters()).dtype + + full_sequence = prompt_ids.clone() + + # Force eager attention during generation to avoid dynamo recompiles from: + # 1. Variable sequence lengths per gen step → size-mismatch recompiles + # 2. 
no_grad vs grad toggling → grad_mode recompiles + # Both cause dynamo to hit the recompile limit → unfused fallback → OOM + unwrapped = self.accelerator.unwrap_model(model) + with override_attn_implementation(unwrapped, "eager"): + for generation_step in range(gen_len): + cur_len = full_sequence.shape[1] + + dense_mask, pos_ids = build_strided_dense_mask_and_positions( + full_sequence_length=cur_len, + prompt_length=seq_len, + context_length=anchor_offset, + generation_step=generation_step, + max_generation_length=gen_len, + stride=stride, + num_blocks=num_blocks, + device=device, + batch_size=B, + dtype=model_dtype, + ) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + output = model( + full_sequence, + attention_mask=dense_mask, + position_ids=pos_ids, + return_dict=True, + ) + all_logits = output.logits + + logit_positions = [] + for block_idx in range(num_blocks): + if generation_step == 0: + # Last token of the context window predicts the first rollout token + pos = anchor_offset + block_idx * stride - 1 + else: + pos = seq_len + (generation_step - 1) * num_blocks + block_idx + logit_positions.append(pos) + + position_indices = torch.tensor(logit_positions, device=device) + block_logits = all_logits.index_select(1, position_indices) + + if temperature > 0: + block_logits = block_logits / temperature + probs = torch.softmax(block_logits, dim=-1) + + if top_p < 1.0: + sorted_probs, sorted_idx = torch.sort( + probs, descending=True, dim=-1 + ) + cumulative = torch.cumsum(sorted_probs, dim=-1) + remove = cumulative > top_p + remove[..., 1:] = remove[..., :-1].clone() + remove[..., 0] = False + mask = torch.zeros_like(probs, dtype=torch.bool) + mask.scatter_(-1, sorted_idx, remove) + probs[mask] = 0 + probs = probs / probs.sum(dim=-1, keepdim=True) + + flat_probs = probs.view(-1, probs.shape[-1]) + sampled = torch.multinomial(flat_probs, 1).squeeze(-1) + sampled = sampled.view(B, num_blocks) + else: + sampled = torch.argmax(block_logits, dim=-1) + + full_sequence = torch.cat([full_sequence, sampled], dim=1) + + return full_sequence + + @torch._dynamo.disable + @torch.no_grad() + def _compute_block_rewards( + self, + full_sequences, + attn_mask, + pos_ids, + original_ids, + num_blocks, + batch_size, + n_samples, + anchor_offset=None, + ): + """Extract features and compute per-block rewards. Returns (B, N, NB). + + Args: + anchor_offset: Position where anchors start. For structured data this + is after the prompt; for unstructured it equals context_length. + """ + device = full_sequences.device + seq_len = original_ids.shape[1] + gen_len = self.ebft_generate_max_len + stride = self.ebft_stride + if anchor_offset is None: + anchor_offset = self.ebft_context_length + + # Run feature network on its device WITH the strided attention mask. + # Without the strided mask, generated tokens see tokens from other blocks + # via default causal attention, corrupting the feature representations. + fd = self._feature_device + fn_seqs = full_sequences.to(fd) + fn_pos = pos_ids.to(fd) + + # Determine which model to use for feature extraction + if self._share_feature_weights: + # Use actor's base weights with adapters disabled. + # Force eager attention to avoid grad_mode recompiles on the shared + # compiled flex_attention kernel (feature extraction is no_grad, + # training forward is with grad — each switch recompiles). 
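+            # (The override below selects SDPA rather than plain eager: both
+            # avoid the recompiles, but SDPA keeps attention fused.)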
+            unwrapped_actor = self.accelerator.unwrap_model(self.model)
+            feat_model = unwrapped_actor
+            feature_ctx = unwrapped_actor.disable_adapter()
+            # Use SDPA (flash attention) instead of flex to avoid grad_mode recompiles
+            # on the shared compiled flex kernel. SDPA is fused (no score matrix
+            # materialization) and needs no compilation — ideal for no_grad feature extraction.
+            attn_ctx = override_attn_implementation(unwrapped_actor, "sdpa")
+            use_flex_for_features = False
+        else:
+            feat_model = self.feature_network
+            feature_ctx = contextlib.nullcontext()
+            attn_ctx = contextlib.nullcontext()
+            use_flex_for_features = self._feature_use_flex
+
+        # Build strided mask — flex block mask if available, else dense 4D
+        if use_flex_for_features:
+            fn_attn_mask = create_strided_block_mask(
+                prompt_length=seq_len,
+                context_length=anchor_offset,
+                max_generation_length=gen_len,
+                stride=stride,
+                num_blocks=num_blocks,
+                full_sequence_length=full_sequences.shape[1],
+                batch_size=full_sequences.shape[0],
+                num_heads=feat_model.config.num_attention_heads,
+                device=fd,
+            )
+        else:
+            fn_attn_mask, _ = build_strided_dense_mask_and_positions(
+                full_sequence_length=full_sequences.shape[1],
+                prompt_length=seq_len,
+                context_length=anchor_offset,
+                generation_step=gen_len,
+                max_generation_length=gen_len,
+                stride=stride,
+                num_blocks=num_blocks,
+                device=fd,
+                batch_size=full_sequences.shape[0],
+                dtype=torch.bfloat16,
+            )
+
+        with (
+            feature_ctx,
+            attn_ctx,
+            torch.autocast(device_type="cuda", dtype=torch.bfloat16),
+        ):
+            was_training = feat_model.training
+            feat_model.eval()
+            fn_outputs = feat_model(
+                fn_seqs,
+                attention_mask=fn_attn_mask,
+                position_ids=fn_pos,
+                output_hidden_states=True,
+                return_dict=True,
+            )
+            if was_training:
+                feat_model.train()
+            # Selected layers, moved back to the actor's device (the feature
+            # network may live on a different GPU).
+            hidden_states_actor = [
+                fn_outputs.hidden_states[idx].to(device)
+                for idx in self.feature_layer_indices
+            ]
+            del fn_outputs, fn_seqs, fn_pos, fn_attn_mask
+
+        # Normalize each layer's hidden states separately (like the reference critic),
+        # then concatenate. This prevents one dominant layer from suppressing others.
+        normalized_layers = [F.normalize(h, p=2, dim=-1) for h in hidden_states_actor]
+        features = torch.cat(normalized_layers, dim=-1).to(device)
+        del hidden_states_actor, normalized_layers
+
+        # Ground-truth features start from anchor_offset (not ctx_len) so they
+        # align with where anchors are actually placed.
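+        # (Illustrative: anchor_offset=16, seq_len=512, gen_len=8, stride=8
+        # gives num_blocks=62; gt_features unfolds into 62 windows of 8 tokens,
+        # one aligned with each generated block.)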
+ gt_features = features[:, anchor_offset:seq_len, :] + # Only take actual generated tokens (exclude padding beyond num_blocks * gen_len) + gen_features = features[:, seq_len : seq_len + num_blocks * gen_len, :] + + gt_block_features = gt_features.unfold(1, gen_len, stride).permute(0, 1, 3, 2) + gen_block_features = gen_features.reshape( + batch_size * n_samples, gen_len, num_blocks, -1 + ).transpose(1, 2) + + if self.ebft_embed_method == "mean_pooling": + gt_emb = gt_block_features.mean(dim=2) + gen_emb = gen_block_features.mean(dim=2) + else: # last_token + gt_emb = gt_block_features[:, :, -1, :] + gen_emb = gen_block_features[:, :, -1, :] + + gt_emb = gt_emb.view(batch_size, n_samples, num_blocks, -1) + gen_emb = gen_emb.view(batch_size, n_samples, num_blocks, -1) + + if self.ebft_use_whitening: + whitened_gen, whitened_gt = [], [] + for b in range(batch_size): + for nb in range(num_blocks): + w_gen, w_gt = whiten_embeddings_batched( + gen_emb[b, :, nb, :], + gt_emb[b, :, nb, :], + ) + whitened_gen.append(w_gen) + whitened_gt.append(w_gt) + gen_emb = ( + torch.stack(whitened_gen) + .view(batch_size, num_blocks, n_samples, -1) + .transpose(1, 2) + ) + gt_emb = ( + torch.stack(whitened_gt) + .view(batch_size, num_blocks, n_samples, -1) + .transpose(1, 2) + ) + + alignment = F.cosine_similarity(gen_emb, gt_emb, dim=-1) + + # Batched diversity: reshape to avoid per-block Python loop + diversity = torch.zeros_like(alignment) + if n_samples > 1: + # (B, N, NB, D) → (B*NB, N, D) for a single batched bmm + gen_for_div = gen_emb.permute(0, 2, 1, 3).reshape( + batch_size * num_blocks, n_samples, -1 + ) + sims = torch.bmm(gen_for_div, gen_for_div.transpose(1, 2)) # (B*NB, N, N) + eye = torch.eye(n_samples, device=device, dtype=torch.bool) + sims = sims.masked_fill(eye.unsqueeze(0), 0.0) + div_flat = sims.sum(dim=-1) / (n_samples - 1) # (B*NB, N) + diversity = div_flat.view(batch_size, num_blocks, n_samples).permute( + 0, 2, 1 + ) # (B, N, NB) + + # Scale by 2 per paper equation (7): + # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'}) + alignment = alignment * 2 + diversity = diversity * 2 + + # Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2) + # Mean generated embedding per prompt, squared distance to GT + mean_gen_emb = gen_emb.mean(dim=1, keepdim=True) # (B, 1, NB, D) + gt_for_cfm = gt_emb[:, 0:1, :, :] # (B, 1, NB, D) — one GT per prompt + cfm_loss = ((mean_gen_emb - gt_for_cfm) ** 2).sum(dim=-1).mean() + + # Store for logging + self._last_alignment = alignment.mean().item() + self._last_diversity = diversity.mean().item() + self._last_cfm = cfm_loss.item() + + return ( + self.ebft_alignment_coef * alignment - self.ebft_diversity_coef * diversity + ) + + def _compute_advantages(self, rewards, batch_size, n_samples, num_blocks): + """Compute RLOO advantages. 
rewards: (B, N, NB) → (B*N, NB).""" + if self.ebft_advantage_estimator == "rloo" and n_samples > 1: + total = rewards.sum(dim=1, keepdim=True) + baseline = (total - rewards) / (n_samples - 1) + advantages = rewards - baseline + elif self.ebft_advantage_estimator == "group_norm" and n_samples > 1: + mean = rewards.mean(dim=1, keepdim=True) + std = rewards.std(dim=1, keepdim=True) + 1e-8 + advantages = (rewards - mean) / std + else: + advantages = rewards + return advantages.view(batch_size * n_samples, num_blocks) diff --git a/src/axolotl/core/trainers/ebft/trainer.py b/src/axolotl/core/trainers/ebft/trainer.py new file mode 100644 index 000000000..1c27fa91f --- /dev/null +++ b/src/axolotl/core/trainers/ebft/trainer.py @@ -0,0 +1,531 @@ +""" +EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer. + +Extends AxolotlGRPOTrainer by plugging feature-matching rewards into +the standard GRPO reward function interface. + +Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models" + (Jelassi et al., 2026) https://arxiv.org/abs/2603.12248 +""" + +import contextlib +import copy +from typing import TYPE_CHECKING, Any + +import torch +from datasets import Dataset, IterableDataset +from peft import PeftModel +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback + +from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig +from axolotl.core.trainers.ebft.rewards import ( + apply_embed_method, + extract_hidden_states, + get_alignment_rewards, + get_diversity_rewards, + whiten_embeddings_batched, +) +from axolotl.core.trainers.grpo.trainer import ( + AxolotlAsyncGRPOTrainer, + AxolotlGRPOTrainer, +) +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from collections import defaultdict + + from accelerate import Accelerator + from trl.generation.vllm_generation import VLLMGeneration + +LOG = get_logger(__name__) + + +class EBFTMixin: + """ + Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer. + + Provides: + - Frozen feature network setup (shared weights for PEFT, deepcopy otherwise) + - _feature_matching_reward() callable for GRPO reward function interface + - _sequential_rollout() for multi-turn conversations + """ + + # Type stubs for attributes provided by the composed GRPOTrainer base class. + # These are not defined here but accessed via cooperative multiple inheritance. 
+ if TYPE_CHECKING: + accelerator: Accelerator + model: PreTrainedModel + args: AxolotlEBFTConfig + processing_class: PreTrainedTokenizerBase + num_generations: int + vllm_generation: VLLMGeneration + _metrics: defaultdict + + _tag_names = ["trl", "ebft", "axolotl"] + + def __init__( + self, + model: str | PreTrainedModel, + args: AxolotlEBFTConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset + | IterableDataset + | dict[str, Dataset | IterableDataset] + | None = None, + processing_class: PreTrainedTokenizerBase | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[ + torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None + ] = (None, None), + peft_config: Any | None = None, + ): + # Pass our feature-matching reward function to GRPOTrainer + # It will be called with (prompts, completions, **kwargs) where + # kwargs includes all extra dataset fields like "ground_truth" + super().__init__( # type: ignore[call-arg] + model=model, + reward_funcs=[self._feature_matching_reward], + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + assert args is not None + + # --- Feature network setup --- + unwrapped = self.accelerator.unwrap_model(self.model) + # Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping + self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr( + unwrapped, "disable_adapter" + ) + + if self._share_feature_weights: + # Share weights: use actor's base model with adapters disabled. + # Saves a full model copy (~8 GB for 4B model). + self.feature_network = None + param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9 + LOG.info( + f"EBFT feature network shares actor weights (PEFT disable_adapter). " + f"Saving ~{param_gb:.1f} GB" + ) + else: + LOG.info("Creating frozen feature network for EBFT (deepcopy)...") + self.feature_network = copy.deepcopy(unwrapped) + for param in self.feature_network.parameters(): + param.requires_grad = False + self.feature_network.eval() + + # Compute layer indices from fractional depths + # Handle VLM models where num_hidden_layers is on text_config + config = unwrapped.config + if hasattr(config, "text_config") and hasattr( + config.text_config, "num_hidden_layers" + ): + config = config.text_config + num_layers = config.num_hidden_layers + self.feature_layer_indices = [ + int(frac * num_layers) for frac in args.ebft_feature_layers + ] + LOG.info( + f"EBFT feature extraction from layers {self.feature_layer_indices} " + f"(of {num_layers} total), embed_method={args.ebft_embed_method}" + ) + if args.ebft_adaptive_max_tokens: + LOG.info( + f"EBFT adaptive max_tokens enabled " + f"(gt_length_multiplier={args.ebft_gt_length_multiplier})" + ) + + _adaptive_max_lock = None # initialized lazily + + def _generate_only(self, inputs, rank0_only=False): + """Override to set per-batch max_tokens based on ground-truth length. + + Uses a lock to prevent race conditions in async mode where concurrent + BG threads could interleave mutations of max_completion_length. 
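+
+        Illustrative arithmetic: with gt_length_multiplier=1.5 and ground
+        truths of 120/200/80 tokens, per-sample caps are
+        min(int(c * 1.5), limit) = 180/300/120; their max, 300, is kept
+        (already above the floor of 64), so this batch generates at most
+        300 completion tokens (assuming the configured limit is >= 300).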
+ """ + import threading + + args = self.args + if ( + args.ebft_adaptive_max_tokens + and hasattr(self, "vllm_generation") + and inputs + ): + gt_texts = [ + x.get("ground_truth", "") for x in inputs if x.get("ground_truth") + ] + if gt_texts: + gt_token_counts = [ + len(self.processing_class.encode(gt, add_special_tokens=False)) + for gt in gt_texts + ] + multiplier = args.ebft_gt_length_multiplier + max_completion = self.vllm_generation.max_completion_length + adaptive_max = max( + min(int(c * multiplier), max_completion) for c in gt_token_counts + ) + adaptive_max = max(adaptive_max, 64) + + if self._adaptive_max_lock is None: + self._adaptive_max_lock = threading.Lock() + with self._adaptive_max_lock: + original = self.vllm_generation.max_completion_length + self.vllm_generation.max_completion_length = adaptive_max + try: + return super()._generate_only(inputs, rank0_only) + finally: + self.vllm_generation.max_completion_length = original + + return super()._generate_only(inputs, rank0_only) + + @torch.no_grad() + def _feature_matching_reward( + self, + prompts: list, + completions: list, + ground_truth: list[str] | None = None, + remaining_turns: list | None = None, + **kwargs, + ) -> list[float]: + """ + Compute feature-matching rewards for generated completions. + + This is called by GRPOTrainer's _generate_and_score_completions() + as a standard reward function. The `ground_truth` field comes from + the dataset via reward_kwargs. + + For multi-turn conversations, `remaining_turns` contains the subsequent + user/assistant turn pairs. When present, we do sequential rollouts: + generate each assistant turn conditioned on history + previous generations, + then compute feature-matching rewards on the full generated conversation. + + Args: + prompts: List of prompt strings/messages + completions: List of generated completion strings + ground_truth: List of reference completion strings (from dataset) + remaining_turns: List of remaining conversation turns after the + first assistant turn (for multi-turn rollouts) + + Returns: + List of scalar rewards, one per completion + """ + if ground_truth is None: + LOG.warning("No ground_truth field in dataset — using zero rewards") + return [0.0] * len(prompts) + + device = self.accelerator.device + args = self.args + num_gens = self.num_generations + + # --- Multi-turn sequential rollout --- + # If remaining_turns is provided, generate subsequent assistant turns + # by calling vLLM for each turn, building up the full conversation. 
+ if remaining_turns is not None and hasattr(self, "vllm_generation"): + completions = self._sequential_rollout( + prompts, completions, remaining_turns, num_gens + ) + + # --- Tokenize generated sequences: prompt + completion --- + gen_texts = [] + gen_prompt_texts = [] + for p, c in zip(prompts, completions, strict=True): + if isinstance(p, list): + prompt_text = self.processing_class.apply_chat_template( + p, tokenize=False, add_generation_prompt=True + ) + else: + prompt_text = p + if isinstance(c, list): + comp_text = c[0].get("content", "") if c else "" + else: + comp_text = c + gen_texts.append(prompt_text + comp_text) + gen_prompt_texts.append(prompt_text) + + gen_encoded = self.processing_class( + text=gen_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=getattr(self.args, "max_length", None) + or getattr(self.args, "max_seq_length", None) + or 2048, + add_special_tokens=False, + ) + gen_ids = gen_encoded["input_ids"].to(device) + gen_mask = gen_encoded["attention_mask"].to(device) + + # Compute prompt lengths for completion_mean pooling + gen_prompt_lengths = torch.tensor( + [ + len(self.processing_class.encode(pt, add_special_tokens=False)) + for pt in gen_prompt_texts + ], + device=device, + ) + + # --- Tokenize ground-truth sequences: prompt + ground_truth --- + # For multi-turn (remaining_turns present), render the full GT conversation + # through the chat template to preserve role markers between turns. + gt_texts = [] + gt_prompt_texts = [] + for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)): + if i % num_gens != 0: + continue # Only need one GT per prompt group + if isinstance(p, list): + prompt_text = self.processing_class.apply_chat_template( + p, tokenize=False, add_generation_prompt=True + ) + # Multi-turn: build full GT conversation with remaining turns + if remaining_turns is not None: + prompt_idx = i // num_gens + turns = ( + remaining_turns[prompt_idx] + if prompt_idx < len(remaining_turns) + else [] + ) + if turns: + gt_conv = list(p) + [{"role": "assistant", "content": gt}] + gt_conv.extend(turns) + full_gt_text = self.processing_class.apply_chat_template( + gt_conv, tokenize=False, add_generation_prompt=False + ) + gt_texts.append(full_gt_text) + gt_prompt_texts.append(prompt_text) + continue + else: + prompt_text = p + gt_texts.append(prompt_text + gt) + gt_prompt_texts.append(prompt_text) + + gt_encoded = self.processing_class( + text=gt_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=getattr(self.args, "max_length", None) + or getattr(self.args, "max_seq_length", None) + or 2048, + add_special_tokens=False, + ) + gt_ids = gt_encoded["input_ids"].to(device) + gt_mask = gt_encoded["attention_mask"].to(device) + + gt_prompt_lengths = torch.tensor( + [ + len(self.processing_class.encode(pt, add_special_tokens=False)) + for pt in gt_prompt_texts + ], + device=device, + ) + + # --- Extract features from frozen feature network --- + # INVARIANT: disable_adapter() yields the unmodified base weights because + # _sync_peft_weights_no_merge and _sync_lora_adapter never call + # merge_adapter() — they compute merged weights as new tensors or save + # the adapter to filesystem. Base weights are never modified in-place. 
+ if self._share_feature_weights: + unwrapped = self.accelerator.unwrap_model(self.model) + feature_ctx = unwrapped.disable_adapter() + else: + unwrapped = self.feature_network + feature_ctx = contextlib.nullcontext() + + with feature_ctx: + was_training = unwrapped.training + unwrapped.eval() + gen_hidden = extract_hidden_states( + unwrapped, gen_ids, gen_mask, self.feature_layer_indices + ) + gt_hidden = extract_hidden_states( + unwrapped, gt_ids, gt_mask, self.feature_layer_indices + ) + if was_training: + unwrapped.train() + + # --- Pool to sequence-level embeddings --- + gen_emb = apply_embed_method( + gen_hidden, + args.ebft_embed_method, + gen_mask, + prompt_lengths=gen_prompt_lengths, + ) + gt_emb = apply_embed_method( + gt_hidden, + args.ebft_embed_method, + gt_mask, + prompt_lengths=gt_prompt_lengths, + ) + + # --- Optional whitening --- + batch_size = gen_emb.shape[0] + if args.ebft_use_whitening and batch_size > 1: + num_prompts = batch_size // num_gens + gen_reshaped = gen_emb.view(num_prompts, num_gens, -1) + whitened_gen_list = [] + whitened_gt_list = [] + for i in range(num_prompts): + w_gen, w_gt = whiten_embeddings_batched( + gen_reshaped[i], gt_emb[i : i + 1] + ) + whitened_gen_list.append(w_gen) + whitened_gt_list.append(w_gt) + gen_emb = torch.cat(whitened_gen_list, dim=0) + gt_emb = torch.cat(whitened_gt_list, dim=0) + else: + gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1) + gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1) + + # Repeat gt_emb: each GT repeated num_generations times + gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0) + + # --- Compute rewards --- + alignment = get_alignment_rewards(gen_emb, gt_emb_expanded) + diversity = get_diversity_rewards(gen_emb, num_gens) + + # Scale by 2 per paper equation (7): + # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'}) + alignment = alignment * 2 + diversity = diversity * 2 + + rewards = ( + args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity + ) + + # Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2) + gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1]) + mean_gen = gen_reshaped.mean(dim=1) # (num_prompts, D) + cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean() + + # Log feature-matching metrics to console and wandb + _align = alignment.mean().item() + _divers = diversity.mean().item() + _reward = rewards.mean().item() + _cfm = cfm_loss.item() + + LOG.info( + f"ebft reward | " + f"align {_align:+.3f} ^ | " + f"divers {_divers:+.3f} v | " + f"cfm {_cfm:.3f} v | " + f"reward {_reward:+.3f} ^" + ) + + # Log to wandb via trainer's _metrics (picked up by GRPO's logging) + mode = "train" if self.model.training else "eval" + if hasattr(self, "_metrics"): + self._metrics[mode]["ebft/alignment"].append(_align) + self._metrics[mode]["ebft/diversity"].append(_divers) + self._metrics[mode]["ebft/cfm_loss"].append(_cfm) + self._metrics[mode]["ebft/reward"].append(_reward) + + return rewards.cpu().tolist() + + @torch.no_grad() + def _sequential_rollout( + self, + prompts: list, + first_completions: list, + remaining_turns: list, + num_gens: int, + ) -> list: + """ + Extend single-turn completions into multi-turn conversations. + + For each prompt group, takes the first generated assistant turn and + sequentially generates subsequent assistant turns by calling vLLM, + building up a full multi-turn conversation. 
+ + Args: + prompts: List of prompt message lists (repeated num_gens times) + first_completions: List of generated first-turn completions + remaining_turns: List of remaining turn pairs after first assistant turn. + Each element is a list of dicts: [{"role": "user", "content": "..."}, + {"role": "assistant", "content": "...GT..."}] + num_gens: Number of generations per prompt + + Returns: + Extended completions incorporating all generated turns + """ + vllm_client = self.vllm_generation.vllm_client + max_tokens = getattr(self.args, "max_completion_length", 256) + temperature = getattr(self.args, "temperature", 0.7) + gen_kwargs = getattr(self.args, "generation_kwargs", None) or {} + + extended_completions = [] + + for idx in range(len(prompts)): + prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else [] + first_comp = first_completions[idx] + + # Extract first completion text + if isinstance(first_comp, list): + first_text = first_comp[0].get("content", "") if first_comp else "" + else: + first_text = first_comp + + # Get remaining turns for this prompt (same for all num_gens copies) + prompt_idx = idx // num_gens + turns = ( + remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else [] + ) + + if not turns: + extended_completions.append(first_text) + continue + + # Build conversation with generated first turn + conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}] + + # Generate subsequent turns + for turn in turns: + if turn["role"] == "user": + conv.append(turn) + elif turn["role"] == "assistant": + try: + result = vllm_client.chat( + messages=[conv], + n=1, + max_tokens=max_tokens, + temperature=temperature, + generation_kwargs=gen_kwargs, + ) + gen_ids = result.get("completion_ids", [[]])[0] + gen_text = self.processing_class.decode( + gen_ids, skip_special_tokens=True + ) + except Exception as e: + LOG.warning(f"Multi-turn rollout generation failed: {e}") + gen_text = "" + + conv.append({"role": "assistant", "content": gen_text}) + + # Render full conversation through chat template, then extract + # everything after the original prompt as the "completion" text. + # This preserves role markers and formatting between turns. + full_rendered = self.processing_class.apply_chat_template( + conv, tokenize=False, add_generation_prompt=False + ) + prompt_rendered = self.processing_class.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + completion_text = full_rendered[len(prompt_rendered) :] + extended_completions.append(completion_text) + + return extended_completions + + +class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer): + """EBFT trainer using synchronous GRPO (standard vLLM generation).""" + + pass + + +class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer): + """EBFT trainer using async GRPO (prefetches next batch during training).""" + + pass diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index acfd02909..2b8cda6d8 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer): """ def __init__(self, *args, **kwargs): - # When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration. - # The communicator is not needed because weight sync happens via filesystem + HTTP, - # and it fails when vLLM and a trainer rank share the same CUDA device. 
+ # Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only + # merged weight sync. NCCL is only needed for the standard update_named_param + # path which broadcasts tensors through the communicator. training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None) - if training_args is not None and getattr( - training_args, "vllm_lora_sync", False - ): + _skip_nccl = False + if training_args is not None: + if getattr(training_args, "vllm_lora_sync", False): + _skip_nccl = True # LoRA sync uses filesystem + HTTP + elif getattr(training_args, "async_prefetch", False): + # Skip NCCL at init to avoid DDP param count mismatch in multi-GPU. + # init_communicator allocates device tensors on rank 0 only, which + # causes DDP to see different param counts across ranks. + # The communicator is initialized lazily on first weight sync instead. + _skip_nccl = True + if _skip_nccl: from trl.generation.vllm_generation import VLLMGeneration _orig_init_vllm = VLLMGeneration._init_vllm @@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer): VLLMGeneration._init_vllm = _init_vllm_no_communicator - super().__init__(*args, **kwargs) + try: + super().__init__(*args, **kwargs) + finally: + # Restore original _init_vllm so other trainers aren't affected + if _skip_nccl: + VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined] # FP8 models: zero out the pad token embedding so that padding # positions have zero hidden states throughout the network. @@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer): self._executor = None def _submit_generation(self): - """Submit the next background generation job.""" + """Submit the next background generation job. + + With multi-process (DDP/FSDP), only rank 0 generates to avoid + cross-rank NCCL collectives from background threads. Non-rank-0 + processes enqueue a sentinel ``None`` that is replaced by a + broadcast in ``_prepare_inputs_legacy_async``. + """ + rank0_only = self.accelerator.num_processes > 1 + if rank0_only and not self.accelerator.is_main_process: + # Non-rank-0: nothing to generate; enqueue a resolved None future + f: concurrent.futures.Future = concurrent.futures.Future() + f.set_result(None) + self._async_queue.put(f) + return batch = next(self._prompt_iter) - future = self._executor.submit(self._generate_only, batch) + future = self._executor.submit(self._generate_only, batch, rank0_only) self._async_queue.put(future) + # ------------------------------------------------------------------ + # Broadcast rollout (legacy async, multi-process) + # ------------------------------------------------------------------ + + def _broadcast_rollout(self, rollout: dict | None) -> dict: + """Broadcast a rank0-only rollout dict to all ranks (main thread). + + Rank 0 has the full rollout dict from ``_generate_only``; other ranks + have ``None``. After broadcast, tensors are moved to each rank's + local device. 
+ """ + import torch.distributed as dist + + obj_list = [rollout if self.accelerator.is_main_process else None] + dist.broadcast_object_list(obj_list, src=0) + rollout = obj_list[0] + assert rollout is not None, "broadcast_object_list failed to deliver rollout" + + # Move tensors to local device (broadcast deserializes to CPU) + device = self.accelerator.device + for key, val in rollout.items(): + if isinstance(val, torch.Tensor) and val.device != device: + rollout[key] = val.to(device) + + return rollout + # ------------------------------------------------------------------ # Weight sync # ------------------------------------------------------------------ @@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer): for Float8), and also safe for concurrent use since it never modifies base weights in-place. """ - model = self.vllm_generation.model accelerator = self.vllm_generation.accelerator - vllm_client = self.vllm_generation.vllm_client - fix_name = self.vllm_generation._fix_param_name_to_vllm - if not (self.vllm_generation.mode == "server" and accelerator.is_main_process): return + # In multi-GPU async mode, we skip NCCL communicator init to avoid + # DDP param count mismatch and NCCL device conflicts. Weight sync + # uses the HTTP-only fallback in batch_update_named_params instead. + + model = self.vllm_generation.model + vllm_client = self.vllm_generation.vllm_client + fix_name = self.vllm_generation._fix_param_name_to_vllm + # Build lookup: module_path -> (A, B, scaling) for all active LoRA layers lora_info = {} for mod_name, module in model.base_model.model.named_modules(): @@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer): weight_name = pname.replace(".weight_scale_inv", ".weight") scale_inv_lookup[weight_name] = pparam.data - # Iterate all parameters, computing merged weights for LoRA layers. - # Skip LoRA-specific params and FP8 scale params (scales will be - # recomputed by vLLM when it receives the merged bf16 weight). + # Only sync parameters that have LoRA modifications — skip unchanged + # base weights to avoid OOM on the vLLM GPU from allocating the entire + # model's worth of NCCL receive buffers. params_to_sync = [] + compute_dtype = torch.bfloat16 for name, param in model.named_parameters(): vllm_name = name.removeprefix("base_model.model.").replace( ".base_layer", "" @@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer): continue if "original_module" in vllm_name: continue - # Skip FP8 quantization scale parameters - they are recomputed - # on the vLLM side when we update the weight itself if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name: continue + if not vllm_name.endswith(".weight"): + continue + # fix_name strips modules_to_save.default. 
prefix + raw_mod_path = vllm_name[: -len(".weight")] vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."]) + mod_path = vllm_name[: -len(".weight")] + + # Sync weights that have LoRA adapters OR are modules_to_save + is_lora = mod_path in lora_info + is_modules_to_save = raw_mod_path != mod_path # fix_name stripped a prefix + if not is_lora and not is_modules_to_save: + continue data = param.data - compute_dtype = torch.bfloat16 - if vllm_name.endswith(".weight"): - # Dequantize FP8 weights before merging - if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup: - scale_inv = scale_inv_lookup[name] - # Block dequantization: weight * scale_inv (with broadcasting) - fp8_bf16 = data.to(compute_dtype) - if scale_inv.dim() == 2 and fp8_bf16.dim() == 2: - # Block-quantized: scale_inv shape (rows/block, cols/block) - sr, sc = scale_inv.shape - br = fp8_bf16.shape[0] // sr # block height - bc = fp8_bf16.shape[1] // sc # block width - # Reshape → multiply by block scale → reshape back - data = ( - fp8_bf16.reshape(sr, br, sc, bc) - * scale_inv[:, None, :, None].to(compute_dtype) - ).reshape(fp8_bf16.shape) - elif scale_inv.dim() <= 1: - # Per-tensor or per-channel scale - data = fp8_bf16 * scale_inv.to(compute_dtype) - else: - data = fp8_bf16 - elif data.dtype == torch.float8_e4m3fn: - # FP8 but no scale found - just cast (lossy) - data = data.to(compute_dtype) + # Dequantize FP8 weights before merging + if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup: + scale_inv = scale_inv_lookup[name] + fp8_bf16 = data.to(compute_dtype) + if scale_inv.dim() == 2 and fp8_bf16.dim() == 2: + sr, sc = scale_inv.shape + br = fp8_bf16.shape[0] // sr + bc = fp8_bf16.shape[1] // sc + data = ( + fp8_bf16.reshape(sr, br, sc, bc) + * scale_inv[:, None, :, None].to(compute_dtype) + ).reshape(fp8_bf16.shape) + elif scale_inv.dim() <= 1: + data = fp8_bf16 * scale_inv.to(compute_dtype) + else: + data = fp8_bf16 + elif data.dtype == torch.float8_e4m3fn: + data = data.to(compute_dtype) - mod_path = vllm_name[: -len(".weight")] - if mod_path in lora_info: - A, B, s = lora_info[mod_path] - merged = data.to(compute_dtype) + s * ( - B.to(compute_dtype) @ A.to(compute_dtype) - ) - data = merged + if is_lora: + A, B, s = lora_info[mod_path] + merged = data.to(compute_dtype) + s * ( + B.to(compute_dtype) @ A.to(compute_dtype) + ) + params_to_sync.append((vllm_name, merged)) + else: + # modules_to_save: send raw weight (no LoRA merge needed) + params_to_sync.append((vllm_name, data.to(compute_dtype))) - params_to_sync.append((vllm_name, data)) - - # Batch sync all params in one HTTP+NCCL call (vs individual calls) + # Batch sync only LoRA-modified params via HTTP+NCCL if params_to_sync: + sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6 + logger.info( + f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)" + ) vllm_client.batch_update_named_params(params_to_sync) # Reset prefix cache after weight update @@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer): vllm_client = self.vllm_generation.vllm_client url = f"{vllm_client.base_url}/set_lora_adapter/" + sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300 response = requests.post( url, json={ @@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer): "lora_int_id": self._lora_sync_version, "lora_path": adapter_path, }, - timeout=30, + timeout=sync_timeout, ) if response.status_code != 200: logger.warning( @@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer): step 
= self.state.global_step interval = self.args.vllm_sync_interval if step != self._last_synced_step and step % interval == 0: + if step == 0: + logger.info("Skipping vLLM weight sync at step 0 (no training yet)") + self._last_synced_step = step + return if getattr(self.args, "vllm_lora_sync", False): - if step == 0: - logger.info("Skipping LoRA sync at step 0 (no training yet)") - self._last_synced_step = step - return # Native LoRA sync: save adapter to filesystem, vLLM loads it directly self._sync_lora_adapter() else: @@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer): # Background-thread generation (no scoring) # ------------------------------------------------------------------ - def _generate_single_turn(self, prompts, **kwargs): + def _generate_single_turn(self, prompts, *args, **kwargs): """Override to prevent weight sync from background thread and to use no-merge sync for PEFT models (FP8 models can't merge_adapter).""" is_bg = threading.current_thread() is not threading.main_thread() @@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer): self._patched_sync_weights = True try: - return super()._generate_single_turn(prompts, **kwargs) + return super()._generate_single_turn(prompts, *args, **kwargs) finally: if saved_step is not None: self._last_loaded_step = saved_step @@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer): output = vg.vllm_client.chat( messages=unique_prompts, **sampling_params, - chat_template_kwargs=vg.chat_template_kwargs, - tools=vg.tools, - chat_template=vg.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + chat_template=getattr(self, "chat_template", None), ) else: output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params) @@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer): logps_diff = per_token_logps_diff is_ratio = torch.exp(logps_diff) + is_floor = 1.0 / is_cap # symmetric floor (e.g., cap=3.0 -> floor=0.333) if is_mode in ("sequence_truncate", "token_truncate"): - is_ratio = torch.clamp(is_ratio, max=is_cap) + is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap) elif is_mode in ("sequence_mask", "token_mask"): is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0) + is_ratio = is_ratio.clamp(min=is_floor) data["importance_sampling_ratio"] = is_ratio # --- Collect rewards (launched before logprobs, should be done) --- @@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer): seq_is = is_mode in ("sequence_mask", "sequence_truncate") logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff is_ratio = torch.exp(logps_diff) + # Symmetric floor clamp (matches non-streaming path at line ~1651) + is_floor = 1.0 / is_cap if is_mode in ("sequence_truncate", "token_truncate"): - is_ratio = torch.clamp(is_ratio, max=is_cap) + is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap) elif is_mode in ("sequence_mask", "token_mask"): is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0) + is_ratio = is_ratio.clamp(min=is_floor) if "importance_sampling_ratio" not in data: total = len(data["prompt_ids"]) shape = (total, 1) if seq_is else (total, is_ratio.size(1)) @@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer): rollout = future.result() self._submit_generation() + # With multi-process, only rank 0 generated. Broadcast to all ranks. 
+ if self.accelerator.num_processes > 1: + rollout = self._broadcast_rollout(rollout) + if self.args.streaming_partial_batch: micro_batches = self._score_streaming(rollout) else: diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py index 5f5ff3400..61464ee7d 100644 --- a/src/axolotl/integrations/diffusion/callbacks.py +++ b/src/axolotl/integrations/diffusion/callbacks.py @@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback): logger.info("=" * 60) if self.trainer.axolotl_cfg.use_wandb: - if wandb.run is not None: - wandb.log( + if wandb.run is not None: # type: ignore[attr-defined] + wandb.log( # type: ignore[attr-defined] { - "generated_samples": wandb.Table( + "generated_samples": wandb.Table( # type: ignore[attr-defined] columns=[ "step", "original", diff --git a/src/axolotl/monkeypatch/trainer/trl_vllm.py b/src/axolotl/monkeypatch/trainer/trl_vllm.py index a3296df61..e3f57ccf5 100644 --- a/src/axolotl/monkeypatch/trainer/trl_vllm.py +++ b/src/axolotl/monkeypatch/trainer/trl_vllm.py @@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__) def _batch_update_named_params( self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None ): - """Batched weight sync — sends param metadata via HTTP, tensors via NCCL.""" - from transformers import is_torch_xpu_available + """Batched weight sync — uses NCCL if communicator available, HTTP otherwise.""" + has_communicator = getattr(self, "communicator", None) is not None - if chunk_size is None: - chunks = [params] - else: - chunks = [] - current_chunk: list[tuple[str, torch.Tensor]] = [] - current_elements = 0 - for name, weights in params: - n_elem = weights.numel() - if current_chunk and current_elements + n_elem > chunk_size: - chunks.append(current_chunk) - current_chunk = [] - current_elements = 0 - current_chunk.append((name, weights)) - current_elements += n_elem - if current_chunk: - chunks.append(current_chunk) + if has_communicator: + # Fast path: metadata via HTTP, tensors via NCCL + from transformers import is_torch_xpu_available - for chunk in chunks: - param_metadata = [ - {"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)} - for name, weights in chunk - ] - url = f"{self.base_url}/batch_update_named_params/" - response = self.session.post(url, json={"params": param_metadata}) - if response.status_code != 200: - raise Exception(f"Request failed: {response.status_code}, {response.text}") - - for _name, weights in chunk: - if is_torch_xpu_available(): - self.communicator.broadcast(weights, root=self.rank) - else: - self.communicator.broadcast(weights, src=self.rank) - - if is_torch_xpu_available(): - self.communicator.barrier() + if chunk_size is None: + chunks = [params] else: - self.communicator.group.barrier() + chunks = [] + current_chunk: list[tuple[str, torch.Tensor]] = [] + current_elements = 0 + for name, weights in params: + n_elem = weights.numel() + if current_chunk and current_elements + n_elem > chunk_size: + chunks.append(current_chunk) + current_chunk = [] + current_elements = 0 + current_chunk.append((name, weights)) + current_elements += n_elem + if current_chunk: + chunks.append(current_chunk) + + for chunk in chunks: + param_metadata = [ + { + "name": name, + "dtype": str(weights.dtype), + "shape": list(weights.shape), + } + for name, weights in chunk + ] + url = f"{self.base_url}/batch_update_named_params/" + response = self.session.post( + url, json={"params": param_metadata}, timeout=120 + ) + if 
response.status_code != 200: + raise Exception( + f"Request failed: {response.status_code}, {response.text}" + ) + + for _name, weights in chunk: + if is_torch_xpu_available(): + self.communicator.broadcast(weights, root=self.rank) + else: + self.communicator.broadcast(weights, src=self.rank) + + if is_torch_xpu_available(): + self.communicator.barrier() + else: + self.communicator.group.barrier() + else: + # HTTP-only path: encode tensor data in request body (no NCCL needed). + # Batch by byte size to avoid huge HTTP payloads. + MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024 # 10 MB + HTTP_TIMEOUT = 120 # seconds per request + + payload: list[dict] = [] + payload_bytes = 0 + url = f"{self.base_url}/http_update_weights/" + + def _flush(p: list[dict]) -> None: + if not p: + return + response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT) + if response.status_code != 200: + raise Exception( + f"Request failed: {response.status_code}, {response.text}" + ) + + from axolotl.utils.weight_serde import encode_for_http + + for name, weights in params: + entry = encode_for_http(name, weights) + entry_bytes = weights.nelement() * weights.element_size() + + # Flush current batch if adding this entry would exceed limit + if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST: + _flush(payload) + payload = [] + payload_bytes = 0 + + payload.append(entry) + payload_bytes += entry_bytes + + _flush(payload) # send remaining def _update_model_params(self, model: nn.Module, chunk_size: int | None = None): diff --git a/src/axolotl/prompt_strategies/ebft/__init__.py b/src/axolotl/prompt_strategies/ebft/__init__.py new file mode 100644 index 000000000..db46af005 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/__init__.py @@ -0,0 +1,9 @@ +""" +module for EBFT style dataset transform strategies +""" + +from functools import partial + +from ..base import load as load_base + +load = partial(load_base, module_base="axolotl.prompt_strategies.ebft") diff --git a/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py b/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py new file mode 100644 index 000000000..26982c183 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py @@ -0,0 +1,129 @@ +""" +Dataset transform for multi-turn chat data with structured EBFT (vLLM mode). + +Three variants: + +1. `transform` — Uses the FIRST assistant turn as the generation target. + Passes remaining turns as `remaining_turns` for sequential rollout. + The trainer generates turn 1 via GRPO/vLLM, then sequentially generates + subsequent assistant turns, comparing the full conversation to GT. + +2. `transform_last_turn` — Uses the LAST assistant turn as the target. + Simplest approach: the full conversation history is the prompt. + +3. `transform_all_turns` — Explodes each conversation into N examples + (one per assistant turn). Each turn is an independent training example. + Use with batched=True. + +Supports OpenAI chat format: + {"messages": [{"role": ..., "content": ...}, ...]} +""" + + +def transform(cfg, **kwargs): + """Multi-turn with sequential rollout. + + Returns the first assistant turn as ground_truth, plus remaining_turns + for the trainer to do sequential rollout generation. 
+ """ + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + if not messages: + return {"prompt": [], "ground_truth": ""} + + # Split at first assistant turn + prompt_msgs = [] + first_gt = None + remaining = [] + + found_first = False + for msg in messages: + if msg["role"] == "assistant" and not found_first: + first_gt = msg["content"] + found_first = True + elif found_first: + remaining.append(msg) + else: + prompt_msgs.append(msg) + + if first_gt is None: + return {"prompt": prompt_msgs, "ground_truth": ""} + + # Store only the first assistant turn as ground_truth. The full multi-turn + # GT is reconstructed in the reward function via chat template rendering + # (using remaining_turns), which preserves role markers between turns. + return { + "prompt": prompt_msgs, + "ground_truth": first_gt, + "remaining_turns": remaining, + } + + return transform_fn, { + "remove_columns": "__all__", + } + + +def transform_last_turn(cfg, **kwargs): + """Single-turn: use the last assistant turn as the generation target.""" + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + if not messages: + return {"prompt": [], "ground_truth": ""} + + # Find all assistant turns + history = [] + last_prompt = [] + last_gt = "" + for msg in messages: + if msg["role"] == "assistant": + last_prompt = list(history) + last_gt = msg["content"] + history.append(msg) + + return { + "prompt": last_prompt, + "ground_truth": last_gt, + } + + return transform_fn, { + "remove_columns": "__all__", + } + + +def transform_all_turns(cfg, **kwargs): + """Explode: one example per assistant turn. + + Use with datasets.map(batched=True) to produce N examples from + each N-turn conversation. + + Usage in YAML: + type: ebft_chat_multiturn.transform_all_turns + """ + + def transform_fn(examples, tokenizer=None): + all_prompts = [] + all_ground_truths = [] + + messages_list = examples.get("messages", examples.get("conversations", [])) + + for messages in messages_list: + history = [] + for msg in messages: + if msg["role"] == "assistant": + all_prompts.append(list(history)) + all_ground_truths.append(msg["content"]) + history.append(msg) + + return { + "prompt": all_prompts, + "ground_truth": all_ground_truths, + } + + return transform_fn, { + "remove_columns": "__all__", + "batched": True, + } diff --git a/src/axolotl/prompt_strategies/ebft/ebft_opencode.py b/src/axolotl/prompt_strategies/ebft/ebft_opencode.py new file mode 100644 index 000000000..930d314a1 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/ebft_opencode.py @@ -0,0 +1,20 @@ +""" +Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode. + +Maps the dataset's `input` (prompt) and `output` (code solution) fields +to the format expected by the EBFT trainer (prompt + ground_truth). +""" + + +def transform(cfg, **kwargs): + def transform_fn(example, tokenizer=None): + return { + "prompt": [ + {"role": "user", "content": example["input"]}, + ], + "ground_truth": example["output"], + } + + return transform_fn, { + "remove_columns": "__all__", + } diff --git a/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py b/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py new file mode 100644 index 000000000..7c756d158 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/ebft_reasoning.py @@ -0,0 +1,319 @@ +""" +Dataset transform for reasoning/thinking datasets with EBFT. + +Handles datasets where assistant responses contain ... 
reasoning +traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs). + +Two variants: + +1. `transform` — For structured EBFT (vLLM mode): + Returns prompt + ground_truth with thinking tags preserved. + Feature matching compares full responses (thinking + answer). + +2. `transform_answer_only` — For structured EBFT (vLLM mode): + Strips ... from ground_truth, so feature matching + only scores the final answer portion. Use when reasoning chains + can vary but the answer should match. + +3. `transform_strided` — For strided EBFT: + Tokenizes the full conversation with thinking traces. + Optionally masks thinking tokens from CE loss (labels=-100 for think spans) + while still placing anchors in thinking regions for feature matching. + +All variants work with OpenAI chat format: + {"messages": [{"role": "...", "content": "...Answer"}]} +""" + +import re + + +def _strip_thinking(text: str) -> str: + """Remove ... blocks from text.""" + return re.sub(r".*?\s*", "", text, flags=re.DOTALL).strip() + + +def _extract_thinking(text: str) -> tuple[str, str]: + """Split text into (thinking, answer) parts.""" + match = re.search(r"(.*?)\s*(.*)", text, flags=re.DOTALL) + if match: + return match.group(1).strip(), match.group(2).strip() + return "", text.strip() + + +def transform(cfg, **kwargs): + """Full response including thinking traces for feature matching. + + For datasets where assistant content has ... tags in the + content field. The ground_truth includes the full content (thinking + answer). + """ + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + prompt_msgs_snapshot = None + ground_truth = "" + for msg_idx, msg in enumerate(messages): + if msg["role"] == "assistant": + prompt_msgs_snapshot = list(messages[:msg_idx]) + ground_truth = msg["content"] + + return { + "prompt": prompt_msgs_snapshot + if prompt_msgs_snapshot is not None + else messages[:-1], + "ground_truth": ground_truth, + } + + return transform_fn, {"remove_columns": "__all__"} + + +def transform_split_thinking(cfg, **kwargs): + """Split tags into reasoning_content field for native chat template handling. + + For datasets where thinking is embedded in the content field as .... + Splits it into separate reasoning_content and content fields so the model's + chat template can format it natively (e.g., Qwen3.5's reasoning_content support). + + The prompt messages are passed through with reasoning_content properly split, + so vLLM generation with enable_thinking=true produces comparable outputs. + The ground_truth is the full assistant response (thinking + answer) for + feature matching. + + Also works for: + - ... tags + - <|begin_of_thought|>...<|end_of_thought|> tags + """ + _THINKING_PAIRS = [ + ("", ""), + ("", ""), + ("<|begin_of_thought|>", "<|end_of_thought|>"), + ] + + def _split_msg_thinking(msg): + """Split thinking from assistant message content into reasoning_content. + + Always includes reasoning_content key on assistant messages (empty string + if no thinking tags found) to ensure consistent HF dataset schema across + all examples in a batch. 
+ """ + if msg["role"] != "assistant": + return msg + content = msg.get("content", "") + # Already has reasoning_content — pass through + if "reasoning_content" in msg: + return msg + for open_tag, close_tag in _THINKING_PAIRS: + if open_tag in content and close_tag in content: + start = content.find(open_tag) + end = content.find(close_tag) + thinking = content[start + len(open_tag) : end].strip() + answer = content[end + len(close_tag) :].strip() + return { + **msg, + "reasoning_content": thinking, + "content": answer, + } + # No thinking tags — still add reasoning_content for schema consistency + return {**msg, "reasoning_content": ""} + + def _normalize_msg(msg): + """Ensure every message has {role, content, reasoning_content} for HF schema consistency.""" + return { + "role": msg.get("role", ""), + "content": msg.get("content", ""), + "reasoning_content": msg.get("reasoning_content", ""), + } + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + # Split thinking in all assistant messages, then normalize schema + split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages] + + # Build prompt (all messages except last assistant) and ground_truth + prompt_msgs = [] + prompt_msgs_snapshot = None + ground_truth = "" + for msg in split_messages: + if msg["role"] == "assistant": + prompt_msgs_snapshot = list(prompt_msgs) + # ground_truth is the FULL content for feature matching + thinking = msg.get("reasoning_content", "") + answer = msg.get("content", "") + if thinking: + ground_truth = f"\n{thinking}\n\n\n{answer}" + else: + ground_truth = answer + prompt_msgs.append(msg) + + return { + "prompt": prompt_msgs_snapshot + if prompt_msgs_snapshot is not None + else split_messages[:-1], + "ground_truth": ground_truth, + } + + return transform_fn, {"remove_columns": "__all__"} + + +def transform_answer_only(cfg, **kwargs): + """Strip thinking from ground_truth — match features on answer only.""" + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + prompt_msgs = [] + prompt_msgs_snapshot = None + ground_truth = "" + for msg in messages: + if msg["role"] == "assistant": + prompt_msgs_snapshot = list(prompt_msgs) + ground_truth = _strip_thinking(msg["content"]) + prompt_msgs.append(msg) + + return { + "prompt": prompt_msgs_snapshot + if prompt_msgs_snapshot is not None + else messages[:-1], + "ground_truth": ground_truth, + } + + return transform_fn, {"remove_columns": "__all__"} + + +def transform_strided(cfg, **kwargs): + """For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss. + + Config options (via cfg): + - ebft.mask_thinking_ce: bool (default False) + If True, set labels=-100 for tokens inside ... blocks. + Feature matching still uses these positions (anchors are placed everywhere + in the completion span). Only CE auxiliary loss is affected. 
+ """ + seq_len = cfg.sequence_len + mask_thinking = False + if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"): + mask_thinking = cfg.ebft.mask_thinking_ce + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + if tokenizer is None: + for m in messages: + if m.get("role") == "user": + return {"prompt": m["content"]} + return {"prompt": str(messages)} + + pad_id = ( + tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) + + # Tokenize the full conversation with the chat template + full_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + full_enc = tokenizer( + full_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + ) + input_ids = full_enc["input_ids"] + + # Build labels: -100 for non-assistant tokens + labels = [-100] * len(input_ids) + + # Find assistant turn boundaries using incremental tokenization. + # Only the FINAL assistant turn is marked as trainable. + prefix_messages = [] + final_start = None + final_end = None + for msg in messages: + if msg["role"] == "assistant": + prefix_text = tokenizer.apply_chat_template( + prefix_messages, + tokenize=False, + add_generation_prompt=True, + ) + prefix_ids = tokenizer( + prefix_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + )["input_ids"] + start = len(prefix_ids) + + prefix_messages.append(msg) + with_turn_text = tokenizer.apply_chat_template( + prefix_messages, + tokenize=False, + add_generation_prompt=False, + ) + with_turn_ids = tokenizer( + with_turn_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + )["input_ids"] + end = len(with_turn_ids) + + # Record this turn's boundaries; only the last one will be used + final_start = start + final_end = end + else: + prefix_messages.append(msg) + + # Mark only the final assistant turn as trainable + if final_start is not None and final_end is not None: + for i in range(final_start, min(final_end, len(labels))): + labels[i] = input_ids[i] + + # Optionally mask ... tokens within this turn. + # Find think spans by scanning for and token IDs + # directly in the input_ids (robust to tokenization alignment). 
+ if mask_thinking: + think_open_id = tokenizer.convert_tokens_to_ids("") + think_close_id = tokenizer.convert_tokens_to_ids("") + if think_open_id != tokenizer.unk_token_id: + # Scan from before the assistant turn start to catch + # tags that are part of the template prefix + scan_start = max(0, final_start - 5) + in_think = False + for i in range(scan_start, min(final_end, len(labels))): + if input_ids[i] == think_open_id: + in_think = True + if in_think and i >= final_start: + labels[i] = -100 + if input_ids[i] == think_close_id: + in_think = False + if i >= final_start: + labels[i] = -100 + + # Derive prompt_length + prompt_length = len(input_ids) + for i, lbl in enumerate(labels): + if lbl != -100: + prompt_length = i + break + + # Pad + pad_len = seq_len - len(input_ids) + attention_mask = [1] * len(input_ids) + [0] * pad_len + labels = labels + [-100] * pad_len + input_ids = input_ids + [pad_id] * pad_len + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompt_length": prompt_length, + } + + return transform_fn, {"remove_columns": "__all__"} diff --git a/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py b/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py new file mode 100644 index 000000000..f1d0c01f9 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py @@ -0,0 +1,110 @@ +""" +Dataset transform for multi-turn chat data with strided EBFT. + +Tokenizes conversations using the model's chat template, producing input_ids +with labels=-100 for system/user turns and real labels for assistant turns. +The strided trainer places anchors only within assistant completion spans. + +Works with datasets in OpenAI chat format: + [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] +""" + + +def transform(cfg, **kwargs): + seq_len = cfg.sequence_len + + def transform_fn(example, tokenizer=None): + messages = example.get("messages", example.get("conversations", [])) + + if tokenizer is None: + # For preview: just return the first user message + for m in messages: + if m.get("role") == "user": + return {"prompt": m["content"]} + return {"prompt": str(messages)} + + pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id + + # Tokenize the full conversation with the chat template + full_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + full_enc = tokenizer( + full_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + ) + input_ids = full_enc["input_ids"] + + # Build labels: -100 for everything except assistant turns. + # Strategy: tokenize incrementally to find assistant turn boundaries. 
+ labels = [-100] * len(input_ids) + + # Tokenize prefix up to each assistant turn to find boundaries + prefix_messages = [] + for msg in messages: + if msg["role"] == "assistant": + # Tokenize prefix (everything before this assistant turn + generation prompt) + prefix_text = tokenizer.apply_chat_template( + prefix_messages, + tokenize=False, + add_generation_prompt=True, + ) + prefix_ids = tokenizer( + prefix_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + )["input_ids"] + start = len(prefix_ids) + + # Tokenize prefix + this assistant turn + prefix_messages.append(msg) + with_turn_text = tokenizer.apply_chat_template( + prefix_messages, + tokenize=False, + add_generation_prompt=False, + ) + with_turn_ids = tokenizer( + with_turn_text, + truncation=True, + max_length=seq_len, + add_special_tokens=False, + return_tensors=None, + )["input_ids"] + end = len(with_turn_ids) + + # Mark assistant tokens as trainable + for i in range(start, min(end, len(labels))): + labels[i] = input_ids[i] + else: + prefix_messages.append(msg) + + # Derive prompt_length as the position of the first non-masked label + prompt_length = len(input_ids) # default: all masked + for i, lbl in enumerate(labels): + if lbl != -100: + prompt_length = i + break + + # Pad to seq_len + pad_len = seq_len - len(input_ids) + attention_mask = [1] * len(input_ids) + [0] * pad_len + labels = labels + [-100] * pad_len + input_ids = input_ids + [pad_id] * pad_len + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompt_length": prompt_length, + } + + return transform_fn, { + "remove_columns": "__all__", + } diff --git a/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py b/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py new file mode 100644 index 000000000..767650575 --- /dev/null +++ b/src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py @@ -0,0 +1,80 @@ +""" +Dataset transform for structured (prompt, completion) data with strided EBFT. + +Tokenizes prompt and completion separately, concatenates into a single +input_ids sequence, and marks prompt tokens with labels=-100 so the +strided trainer knows where to place anchors (completion span only). + +Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct). 
+""" + + +def transform(cfg, **kwargs): + seq_len = cfg.sequence_len + + def transform_fn(example, tokenizer=None): + # Extract prompt and completion from the example + prompt_text = example.get( + "input", example.get("prompt", example.get("question", "")) + ) + completion_text = example.get( + "output", example.get("completion", example.get("answer", "")) + ) + + if tokenizer is None: + return {"prompt": prompt_text} + + pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id + + # Tokenize prompt and completion separately + prompt_enc = tokenizer( + prompt_text, + truncation=False, + add_special_tokens=True, + return_tensors=None, + ) + completion_enc = tokenizer( + completion_text, + truncation=False, + add_special_tokens=False, + return_tensors=None, + ) + + prompt_ids = prompt_enc["input_ids"] + completion_ids = completion_enc["input_ids"] + + # Truncate to fit within seq_len (prioritize keeping prompt + some completion) + total_len = len(prompt_ids) + len(completion_ids) + if total_len > seq_len: + # Truncate completion first, then prompt if needed + max_completion = seq_len - len(prompt_ids) + if max_completion < 1: + # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token + prompt_ids = prompt_ids[: seq_len - 1] + completion_ids = completion_ids[:1] + else: + completion_ids = completion_ids[:max_completion] + + input_ids = prompt_ids + completion_ids + prompt_length = len(prompt_ids) + + # Labels: -100 for prompt tokens, input_ids for completion tokens + labels = [-100] * prompt_length + completion_ids + + # Pad to seq_len + pad_len = seq_len - len(input_ids) + attention_mask = [1] * len(input_ids) + [0] * pad_len + labels = labels + [-100] * pad_len + input_ids = input_ids + [pad_id] * pad_len + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompt_length": prompt_length, + } + + # Signal to remove all original columns (filtered to existing ones at map time) + return transform_fn, { + "remove_columns": "__all__", + } diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index 9ce4d2771..f4fcfa190 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -241,6 +241,23 @@ def main(script_args: ScriptArguments): app = FastAPI(lifespan=lifespan) + # --- Access logging middleware --- + import time as _time + + @app.middleware("http") + async def access_log_middleware(request, call_next): + t0 = _time.monotonic() + response = await call_next(request) + elapsed = _time.monotonic() - t0 + logger.info( + "%s %s %d %.3fs", + request.method, + request.url.path, + response.status_code, + elapsed, + ) + return response + # --- Active LoRA state (shared across endpoints via closure) --- active_lora: dict = {"request": None} @@ -300,7 +317,11 @@ def main(script_args: ScriptArguments): import vllm from packaging.version import Version - from vllm.sampling_params import GuidedDecodingParams + + try: + from vllm.sampling_params import GuidedDecodingParams + except ImportError: + GuidedDecodingParams = None # not available in vLLM 0.17+ images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item] prompts: list[dict[str, Any]] = [] @@ -362,7 +383,12 @@ def main(script_args: ScriptArguments): } conn.send({"type": "call", "method": "generate", "kwargs": kwargs}) - all_outputs = [conn.recv() for conn in connections] + # Use run_in_executor so blocking recv() doesn't freeze the event loop + # 
(allows /set_lora_adapter/ and other endpoints to be served concurrently) + loop = asyncio.get_running_loop() + all_outputs = await asyncio.gather( + *(loop.run_in_executor(None, conn.recv) for conn in connections) + ) all_outputs = [ o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c ] @@ -404,7 +430,10 @@ def main(script_args: ScriptArguments): } conn.send({"type": "call", "method": "chat", "kwargs": kwargs}) - all_outputs = [conn.recv() for conn in connections] + loop = asyncio.get_running_loop() + all_outputs = await asyncio.gather( + *(loop.run_in_executor(None, conn.recv) for conn in connections) + ) all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] all_outputs = list(chain.from_iterable(all_outputs)) @@ -474,11 +503,51 @@ def main(script_args: ScriptArguments): ) return {"message": f"Batch update for {len(params_list)} params"} + class HTTPWeightUpdateRequest(BaseModel): + """Weight update via HTTP (no NCCL needed).""" + + params: list[ + dict + ] # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}] + + @app.post("/http_update_weights/") + async def http_update_weights(request: HTTPWeightUpdateRequest): + """Update model weights via HTTP — no NCCL communicator required. + + Tensor data is sent as base64-encoded raw bytes in the request body. + Slower than NCCL for large models but works without cross-process setup. + """ + from axolotl.utils.weight_serde import ( + decode_from_http, + encode_for_ipc, + ) + + weights_to_load = [decode_from_http(p) for p in request.params] + + # Send all weights in a single IPC call. Tensors don't survive + # vLLM's multiproc IPC, so serialize as raw bytes + metadata. + param_entries = [ + encode_for_ipc(name, weight) for name, weight in weights_to_load + ] + kwargs = { + "method": "http_load_weights_batch", + "kwargs": {"params": param_entries}, + } + msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs} + loop = asyncio.get_running_loop() + await asyncio.gather( + *(loop.run_in_executor(None, c.send, msg) for c in connections) + ) + return {"message": f"HTTP weight update for {len(weights_to_load)} params"} + @app.post("/reset_prefix_cache/") async def reset_prefix_cache(): for conn in connections: conn.send({"type": "call", "method": "reset_prefix_cache"}) - results = [conn.recv() for conn in connections] + loop = asyncio.get_running_loop() + results = await asyncio.gather( + *(loop.run_in_executor(None, conn.recv) for conn in connections) + ) return {"message": f"Reset prefix cache: {all(results)}"} @app.post("/close_communicator/") diff --git a/src/axolotl/scripts/vllm_worker_ext.py b/src/axolotl/scripts/vllm_worker_ext.py index 386460df1..11f8e6ceb 100644 --- a/src/axolotl/scripts/vllm_worker_ext.py +++ b/src/axolotl/scripts/vllm_worker_ext.py @@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension): model = self.model_runner.model params_dict = dict(model.named_parameters()) + # Handle VLM models where trainer and vLLM use different prefixes. + # Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..." + # vLLM (Qwen3.5): "language_model.model.layers.X..." 
+ if name not in params_dict: + # Try common prefix remappings + for src_prefix, dst_prefix in [ + ("model.language_model.layers.", "language_model.model.layers."), + ("model.layers.", "language_model.model.layers."), + ]: + if name.startswith(src_prefix): + name = dst_prefix + name[len(src_prefix) :] + break + # Check if this is a simple direct param (exists as-is) if name in params_dict: params_dict[name].data.copy_(weight.to(params_dict[name].dtype)) @@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension): return # Fallback: try load_weights (may work for non-stacked params) - logger.warning("Falling back to load_weights for param: %s", name) + # Log the actual param names available for debugging + sample_keys = [ + k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k + ][:3] + logger.warning( + "Falling back to load_weights for param: %s (sample vLLM keys: %s)", + name, + sample_keys, + ) model.load_weights(weights=[(name, weight)]) def update_named_param(self, name, dtype, shape): @@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension): # Load weights using direct set (handles stacked params) for name, weight in weights_to_load: self._direct_set_weight(name, weight) + + def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]): + """Load weights received via HTTP (no NCCL needed).""" + for name, weight in weights: + self._direct_set_weight(name, weight.to(self.device)) + + def http_load_weight(self, **kwargs): + """Load a single weight received via HTTP (no NCCL needed). + + Reconstructs the tensor from raw bytes since tensors don't survive + vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights`` + which handles TP sharding and stacked-param packing automatically. + """ + from axolotl.utils.weight_serde import decode_from_ipc + + name, weight = decode_from_ipc(kwargs) + model = self.model_runner.model + model.load_weights(weights=[(name, weight)]) + + def http_load_weights_batch(self, params: list[dict]): + """Load multiple weights in a single IPC call. + + Uses vLLM's ``load_weights`` which handles TP sharding automatically. 
+ """ + from axolotl.utils.weight_serde import decode_from_ipc + + model = self.model_runner.model + weights = [decode_from_ipc(p) for p in params] + model.load_weights(weights=weights) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 522dd7e28..774aa1cec 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -138,7 +138,11 @@ def setup_reference_model( model_ref = None # explicit setting to None else: reference_model: bool = True - if cfg.rl == RLType.GRPO and cfg.trl.beta == 0: + trl_cfg = getattr(cfg, "trl", None) + if ( + cfg.rl in {RLType.GRPO, RLType.EBFT} + and getattr(trl_cfg, "beta", 0) == 0 + ): reference_model = False # load the model again for model_ref/baseline model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model) @@ -206,7 +210,7 @@ def execute_training( gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, - gather_outputs=cfg.rl is RLType.GRPO, + gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT}, device_mesh=trainer.accelerator.torch_device_mesh, ) ) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 36370ef13..afdb7f2a2 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): ].append(pred_step_text) row_index += 1 if logger == "wandb": - # type: ignore[attr-defined] - wandb.run.log( + wandb.run.log( # type: ignore[attr-defined] { f"{name} - Predictions vs Ground Truth": pd.DataFrame( table_data @@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" ) as temp_file: copyfile(self.axolotl_config_path, temp_file.name) - artifact = wandb.Artifact( - f"config-{wandb.run.id}", type="axolotl-config" + artifact = wandb.Artifact( # type: ignore[attr-defined] + f"config-{wandb.run.id}", # type: ignore[attr-defined] + type="axolotl-config", ) artifact.add_file(temp_file.name) - wandb.log_artifact(artifact) - wandb.save(temp_file.name) + wandb.log_artifact(artifact) # type: ignore[attr-defined] + wandb.save(temp_file.name) # type: ignore[attr-defined] LOG.info( "The Axolotl config has been saved to the WandB run under files." ) @@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): temp_ct_file.write(str(chat_tpl)) temp_ct_file.flush() - artifact = wandb.Artifact( - f"chat-template-{wandb.run.id}", type="jinja-template" + artifact = wandb.Artifact( # type: ignore[attr-defined] + f"chat-template-{wandb.run.id}", # type: ignore[attr-defined] + type="jinja-template", ) artifact.add_file(temp_ct_file.name) - wandb.log_artifact(artifact) - wandb.save(temp_ct_file.name) + wandb.log_artifact(artifact) # type: ignore[attr-defined] + wandb.save(temp_ct_file.name) # type: ignore[attr-defined] LOG.info( "The chat_template_jinja has been saved to the WandB run under files." 
) @@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): else: skip_upload = True if not skip_upload: - artifact = wandb.Artifact( - f"deepspeed-config-{wandb.run.id}", + artifact = wandb.Artifact( # type: ignore[attr-defined] + f"deepspeed-config-{wandb.run.id}", # type: ignore[attr-defined] type="deepspeed-config", ) artifact.add_file(temp_file.name) - wandb.log_artifact(artifact) - wandb.save(temp_file.name) + wandb.log_artifact(artifact) # type: ignore[attr-defined] + wandb.save(temp_file.name) # type: ignore[attr-defined] LOG.info( "The DeepSpeed config has been saved to the WandB run under files." ) diff --git a/src/axolotl/utils/callbacks/generation.py b/src/axolotl/utils/callbacks/generation.py index 439258c8b..da36b9ad0 100644 --- a/src/axolotl/utils/callbacks/generation.py +++ b/src/axolotl/utils/callbacks/generation.py @@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback): if not getattr(cfg, "generate_samples", False): return - dataloader = None - try: - if getattr(self.trainer, "eval_dataset", None) is not None: - dataloader = self.trainer.get_eval_dataloader() - LOG.info( - f"Using eval dataloader for generation at step {state.global_step}" - ) - except Exception as e: - LOG.warning(f"Could not get eval dataloader: {e}") - dataloader = None - - if dataloader is None: - dataloader = self.trainer.get_train_dataloader() + dataloader = None + try: + if getattr(self.trainer, "eval_dataset", None) is not None: + dataloader = self.trainer.get_eval_dataloader() LOG.info( - f"Using train dataloader for generation at step {state.global_step}" + f"Using eval dataloader for generation at step {state.global_step}" ) + except Exception as e: + LOG.warning(f"Could not get eval dataloader: {e}") + dataloader = None - samples = generate_samples( - model=self.trainer.model, - tokenizer=self.trainer.processing_class, - dataloader=dataloader, - num_generation_samples=getattr(cfg, "num_generation_samples", 3), - max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50), - temperature=getattr(cfg, "generation_temperature", 0.7), - top_p=getattr(cfg, "generation_top_p", None), - top_k=getattr(cfg, "generation_top_k", None), - do_sample=getattr(cfg, "generation_do_sample", True), - prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5), + if dataloader is None: + dataloader = self.trainer.get_train_dataloader() + LOG.info( + f"Using train dataloader for generation at step {state.global_step}" ) - self._log_samples(samples, state.global_step) + + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.processing_class, + dataloader=dataloader, + num_generation_samples=getattr(cfg, "num_generation_samples", 3), + max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50), + temperature=getattr(cfg, "generation_temperature", 0.7), + top_p=getattr(cfg, "generation_top_p", None), + top_k=getattr(cfg, "generation_top_k", None), + do_sample=getattr(cfg, "generation_do_sample", True), + prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5), + ) + self._log_samples(samples, state.global_step) def _log_samples(self, samples: list, step: int): """Log generated samples to console and W&B.""" @@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback): try: import wandb - if wandb.run is not None: - wandb.log( + if wandb.run is not None: # type: ignore[attr-defined] + wandb.log( # type: ignore[attr-defined] { - f"samples/sample_{i + 1}": wandb.Html( + f"samples/sample_{i + 1}": wandb.Html( # type: ignore[attr-defined] f"
<pre>{wandb_text}</pre>
" ) }, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 2c386f35e..ef91e1124 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer from axolotl.loaders import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo +from axolotl.prompt_strategies.ebft import load as load_ebft from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.utils.data.lock import FileLockLoader @@ -173,7 +174,7 @@ def _drop_long_sequences( return (len_prompt + len_completion) <= sequence_len - if rl in {RLType.GRPO, RLType.GDPO}: + if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}: return True raise ValueError("Unknown RL type") @@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i) elif cfg.rl is RLType.KTO: ds_transform_fn = load_kto(_type, cfg, dataset_idx=i) + elif cfg.rl is RLType.EBFT: + ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i) else: ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i) map_kwargs: dict[str, Any] = {} if isinstance(ds_transform_fn, tuple): ds_transform_fn, map_kwargs = ds_transform_fn + # Handle remove_columns: "__all__" removes all original columns, + # or filter a list to only columns that exist in the dataset + if "remove_columns" in map_kwargs: + ds_columns = ( + dataset.column_names + if isinstance(dataset, Dataset) + else dataset[split].column_names + if isinstance(dataset, DatasetDict) + else [] + ) + if map_kwargs["remove_columns"] == "__all__": + map_kwargs["remove_columns"] = list(ds_columns) + else: + map_kwargs["remove_columns"] = [ + c for c in map_kwargs["remove_columns"] if c in ds_columns + ] split_datasets[i] = _map_dataset( cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 34fd9ba2c..982e3e419 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig LOG = get_logger(__name__) +class EBFTConfig(BaseModel): + """Configuration for Energy-Based Fine-Tuning (EBFT)""" + + feature_layers: list[float] = Field( + default=[0.25, 0.5, 0.75], + json_schema_extra={ + "description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])" + }, + ) + embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = ( + Field( + default="last_token", + json_schema_extra={ + "description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'" + }, + ) + ) + use_whitening: bool = Field( + default=False, + json_schema_extra={"description": "Apply SVD whitening to feature embeddings"}, + ) + alignment_coef: float = Field( + default=1.0, + json_schema_extra={ + "description": "Coefficient for alignment reward (cosine similarity with ground truth)" + }, + ) + diversity_coef: float = Field( + default=1.0, + json_schema_extra={ + "description": "Coefficient for diversity penalty (pairwise similarity between samples)" + }, + ) + ce_coef: float = Field( + default=0.0, + json_schema_extra={ + "description": "Cross-entropy loss coefficient on ground-truth tokens" + }, + ) + adaptive_max_tokens: bool = Field( + default=True, + json_schema_extra={ + "description": "Set per-batch max_tokens based on ground-truth length" + }, + ) 
+ gt_length_multiplier: float = Field( + default=1.5, + ge=0.1, + json_schema_extra={ + "description": "Multiplier for ground-truth token count when computing adaptive max_tokens" + }, + ) + + # Strided mode fields (for unstructured text) + mode: Literal["structured", "strided"] = Field( + default="structured", + json_schema_extra={ + "description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)" + }, + ) + stride: int = Field( + default=8, + ge=1, + json_schema_extra={"description": "Stride between anchor points (tokens)"}, + ) + context_length: int = Field( + default=8, + ge=1, + json_schema_extra={"description": "Context window size per block"}, + ) + generate_max_len: int = Field( + default=8, + ge=1, + json_schema_extra={"description": "Tokens to generate per block"}, + ) + n_samples_per_prompt: int = Field( + default=4, + ge=1, + json_schema_extra={"description": "Independent rollouts per document"}, + ) + temperature: float = Field( + default=0.6, + ge=0.0, + json_schema_extra={ + "description": "Sampling temperature for strided generation" + }, + ) + top_p: float = Field( + default=1.0, + ge=0.0, + le=1.0, + json_schema_extra={"description": "Top-p nucleus sampling threshold"}, + ) + rl_coef: float = Field( + default=1.0, + json_schema_extra={"description": "RL policy gradient loss coefficient"}, + ) + advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field( + default="rloo", + json_schema_extra={ + "description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'" + }, + ) + min_completion_prefix: int = Field( + default=0, + ge=0, + json_schema_extra={ + "description": "Minimum tokens into completion before placing anchors. " + "Skips anchors too close to the prompt boundary where features are dominated by prompt context." + }, + ) + + class AxolotlInputConfig( ModelInputConfig, ModelOutputConfig, @@ -131,7 +244,7 @@ class AxolotlInputConfig( rl: RLType | None = Field( default=None, json_schema_extra={ - "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'" + "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'" }, ) trl: TRLConfig | None = Field( @@ -140,6 +253,12 @@ class AxolotlInputConfig( vllm: VllmConfig | None = Field( default_factory=lambda: VllmConfig(), ) + ebft: EBFTConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration for Energy-Based Fine-Tuning (EBFT)" + }, + ) qat: QATConfig | None = None quantization: PTQConfig | None = None reward_model: bool | None = Field( diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 40fa314f4..7ffa793f2 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -35,6 +35,7 @@ class RLType(str, Enum): ORPO = "orpo" KTO = "kto" SIMPO = "simpo" + EBFT = "ebft" class ChatTemplate(str, Enum): diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 2d7c36f96..4ef42db66 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -1,6 +1,6 @@ """Pydantic models for TRL trainer configuration""" -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, Field @@ -133,6 +133,20 @@ class TRLConfig(BaseModel): "description": "Penalty for tokens that appear in prompt and generated text." 
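+            # vLLM semantics: 1.0 disables the penalty; values > 1.0 discourage
+            # repeated tokens.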
}, ) + generation_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Additional generation parameters passed to vLLM SamplingParams. " + "Useful for stop_token_ids, seed, frequency_penalty, etc." + }, + ) + chat_template_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Additional kwargs for the chat template. " + "E.g., {enable_thinking: false} for Qwen3.5 models." + }, + ) num_iterations: int | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index c902d8703..c7eeb6fa4 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1482,6 +1482,124 @@ class DistributedValidationMixin: return self +class EBFTValidationMixin: + """Validation for EBFT (Energy-Based Fine-Tuning) configuration.""" + + @model_validator(mode="before") + @classmethod + def check_ebft_config_required(cls, data): + """rl: ebft requires an ebft config section.""" + if data.get("rl") == "ebft" and not data.get("ebft"): + raise ValueError( + "`ebft` config section is required when `rl: ebft` is set. " + "Add an `ebft:` section with at least `mode: structured` or `mode: strided`." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_ebft_torch_compile(cls, data): + """torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles + and CheckpointErrors. The flex_attention kernel compiles itself internally — + whole-model torch.compile is not needed and actively harmful.""" + if ( + data.get("rl") == "ebft" + and data.get("torch_compile") is True + and data.get("ebft", {}).get("mode") == "strided" + ): + if data.get("gradient_checkpointing"): + raise ValueError( + "EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` " + "causes CheckpointError (BlockMask metadata mismatch during recomputation). " + "Remove `torch_compile` — the flex_attention kernel compiles itself internally." + ) + LOG.warning( + "EBFT strided mode: `torch_compile: true` causes dynamo recompiles from " + "variable sequence lengths across steps. Consider removing it — " + "flex_attention compiles itself internally." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_ebft_gradient_checkpointing_reentrant(cls, data): + """flex_attention + non-reentrant gradient checkpointing causes CheckpointError.""" + if ( + data.get("rl") == "ebft" + and data.get("ebft", {}).get("mode") == "strided" + and data.get("flex_attention") + and data.get("gradient_checkpointing") + ): + gc_kwargs = data.get("gradient_checkpointing_kwargs") or {} + if not gc_kwargs.get("use_reentrant"): + LOG.warning( + "EBFT strided mode with flex_attention: setting `use_reentrant: true` in " + "gradient_checkpointing_kwargs (required for flex_attention compatibility). " + "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata." 
+ ) + if data.get("gradient_checkpointing_kwargs") is None: + data["gradient_checkpointing_kwargs"] = {} + data["gradient_checkpointing_kwargs"]["use_reentrant"] = True + return data + + @model_validator(mode="before") + @classmethod + def check_ebft_activation_offloading(cls, data): + """activation_offloading replaces gradient checkpointing with FSDP-style wrapping, + which conflicts with flex_attention's use_reentrant requirement.""" + if ( + data.get("rl") == "ebft" + and data.get("ebft", {}).get("mode") == "strided" + and data.get("activation_offloading") is True + and data.get("flex_attention") + ): + raise ValueError( + "EBFT strided mode: `activation_offloading: true` is incompatible with " + "`flex_attention: true`. Activation offloading replaces gradient checkpointing " + "with FSDP-style wrapping that conflicts with flex_attention's reentrant " + "checkpoint requirement. Remove `activation_offloading` — the strided trainer " + "uses micro-batched forward passes for memory efficiency instead." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_ebft_strided_sequence_len(cls, data): + """Warn if sequence_len is too large for single-GPU strided EBFT.""" + if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided": + return data + ebft = data.get("ebft", {}) + seq_len = data.get("sequence_len", 512) + n_samples = ebft.get("n_samples_per_prompt", 4) + gen_len = ebft.get("generate_max_len", 8) + stride = ebft.get("stride", 8) + ctx_len = ebft.get("context_length", 8) + max_blocks = (seq_len - gen_len - ctx_len) // stride + 1 + full_seq = seq_len + max_blocks * gen_len + # Rough estimate: 8.7 GB per sample at S=3900 for 1B model + if full_seq * n_samples > 20000: + LOG.warning( + f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = " + f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM " + f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_ebft_strided_dataset_split(cls, data): + """Warn about the common `train_on_split` mistake (silently ignored by schema).""" + datasets = data.get("datasets", []) + for ds in datasets or []: + if isinstance(ds, dict) and ds.get("train_on_split"): + LOG.warning( + f"Dataset has `train_on_split: {ds['train_on_split']}` — this field " + f"is not recognized and will be silently ignored. " + f"Use `split: {ds['train_on_split']}` instead." + ) + return data + + class GRPOVllmValidationMixin: """Validation mixin for vllm when using GRPO.""" @@ -1507,6 +1625,7 @@ class ValidationMixin( PretrainingValidationMixin, ModelCompatibilityValidationMixin, ComplexValidationMixin, + EBFTValidationMixin, GRPOVllmValidationMixin, ): """Full validation mixin for Axolotl configuration.""" diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py index c0aa48d66..5198d4173 100644 --- a/src/axolotl/utils/schemas/vllm.py +++ b/src/axolotl/utils/schemas/vllm.py @@ -57,6 +57,13 @@ class VllmConfig(BaseModel): default=None, json_schema_extra={"description": "Reasoning parser for VLLM"}, ) + enforce_eager: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Disable CUDA graph capture in vLLM. Required for models with " + "causal_conv1d (e.g., Qwen3.5 hybrid linear attention)." 
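+            # In the axolotl YAML this lives under the vllm section, e.g.:
+            #   vllm:
+            #     enforce_eager: true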
+ }, + ) serve_module: str | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/weight_serde.py b/src/axolotl/utils/weight_serde.py new file mode 100644 index 000000000..d4b804681 --- /dev/null +++ b/src/axolotl/utils/weight_serde.py @@ -0,0 +1,94 @@ +"""Serialize / deserialize tensors for HTTP and IPC weight sync. + +NumPy doesn't support bfloat16, so bf16 tensors are cast to fp16 on the wire +and reconstructed at the destination. All encode/decode helpers live here so +the logic isn't duplicated across trl_vllm.py, vllm_serve_lora.py, and +vllm_worker_ext.py. +""" + +import base64 + +import torch + + +def encode_for_http(name: str, weight: torch.Tensor) -> dict: + """Encode a named parameter for JSON transport over HTTP. + + Returns a dict with keys: name, dtype (original), shape, data (base64). + bf16 tensors are sent as fp16 bytes; the original dtype is preserved in + the ``dtype`` field so the receiver can cast back. + """ + w_cpu = weight.contiguous().cpu() + orig_dtype = str(weight.dtype) + if w_cpu.dtype == torch.bfloat16: + w_cpu = w_cpu.half() + raw = w_cpu.numpy().tobytes() + return { + "name": name, + "dtype": orig_dtype, + "shape": list(weight.shape), + "data": base64.b64encode(raw).decode("ascii"), + } + + +def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]: + """Decode an HTTP-encoded weight entry back to a named tensor. + + Infers wire dtype from byte count (bf16 arrives as fp16) and casts to the + original dtype stored in ``entry["dtype"]``. + """ + target_dtype = getattr(torch, entry["dtype"].split(".")[-1]) + shape = tuple(entry["shape"]) + raw = base64.b64decode(entry["data"]) + + n_elements = 1 + for s in shape: + n_elements *= s + wire_bytes_per_elem = len(raw) // max(n_elements, 1) + if wire_bytes_per_elem == 2: + wire_dtype = torch.float16 + elif wire_bytes_per_elem == 4: + wire_dtype = torch.float32 + else: + wire_dtype = target_dtype + + weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape) + if wire_dtype != target_dtype: + weight = weight.to(target_dtype) + return entry["name"], weight + + +def encode_for_ipc(name: str, weight: torch.Tensor) -> dict: + """Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64). + + Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype + (original), shape. bf16 tensors are serialized as fp16. + """ + w = weight.contiguous() + target_dtype = str(w.dtype).split(".")[-1] + if w.dtype == torch.bfloat16: + w = w.half() + wire_dtype = str(w.dtype).split(".")[-1] + return { + "name": name, + "data": w.numpy().tobytes(), + "dtype": wire_dtype, + "target_dtype": target_dtype, + "shape": list(weight.shape), + } + + +def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]: + """Decode an IPC-encoded weight entry back to a named tensor. + + Handles optional ``target_dtype`` for backward compatibility with older + serve code that may not include it. 
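+
+    Expects the dict layout produced by ``encode_for_ipc``: ``name``, ``data``
+    (raw bytes), ``dtype`` (wire dtype), ``shape``, and optionally
+    ``target_dtype``.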
+ """ + wire_dtype = getattr(torch, entry["dtype"]) + weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape( + entry["shape"] + ) + target_dtype = entry.get("target_dtype") + if target_dtype and target_dtype != entry["dtype"]: + weight = weight.to(getattr(torch, target_dtype)) + return entry["name"], weight diff --git a/tests/test_ebft_kernels.py b/tests/test_ebft_kernels.py new file mode 100644 index 000000000..2c05a887a --- /dev/null +++ b/tests/test_ebft_kernels.py @@ -0,0 +1,294 @@ +"""Correctness tests for fused EBFT Triton kernels.""" + +import pytest +import torch +import torch.nn.functional as F + +from axolotl.core.trainers.ebft.kernels import ( + fused_cosine_similarity, + fused_diversity_penalty, + fused_log_softmax_gather, + fused_reinforce_loss, +) + +# Skip all tests if CUDA not available +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for Triton kernels" +) + +DEVICE = "cuda" + + +# --------------------------------------------------------------------------- +# 1. fused_log_softmax_gather +# --------------------------------------------------------------------------- +class TestFusedLogSoftmaxGather: + def _reference(self, logits, labels): + """PyTorch reference: log_softmax + gather.""" + lp = F.log_softmax(logits.float(), dim=-1) + return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1) + + def test_basic_correctness(self): + B, S, V = 2, 16, 1024 + logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16) + labels = torch.randint(0, V, (B, S), device=DEVICE) + + ref = self._reference(logits, labels) + out = fused_log_softmax_gather(logits, labels) + + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + def test_large_vocab(self): + """Test with realistic vocab size (128K).""" + B, S, V = 1, 8, 128256 + logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16) + labels = torch.randint(0, V, (B, S), device=DEVICE) + + ref = self._reference(logits, labels) + out = fused_log_softmax_gather(logits, labels) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + def test_fp32_input(self): + B, S, V = 2, 8, 512 + logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32) + labels = torch.randint(0, V, (B, S), device=DEVICE) + + ref = self._reference(logits, labels) + out = fused_log_softmax_gather(logits, labels) + + torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5) + + def test_output_is_negative(self): + """log_softmax values should always be <= 0.""" + B, S, V = 4, 32, 2048 + logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16) + labels = torch.randint(0, V, (B, S), device=DEVICE) + + out = fused_log_softmax_gather(logits, labels) + assert (out <= 1e-5).all(), "log_softmax values should be <= 0" + + def test_extreme_logits(self): + """Test numerical stability with very large/small logits.""" + B, S, V = 1, 4, 256 + logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32) + logits[:, 0, :] = 1000.0 # very large + logits[:, 1, :] = -1000.0 # very small + logits[:, 2, 0] = 1000.0 # one hot-ish + labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long) + + ref = self._reference(logits, labels) + out = fused_log_softmax_gather(logits, labels) + + assert torch.isfinite(out).all(), "Should handle extreme values" + torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + """Test with pre-flattened (N, V) input.""" + N, V = 64, 4096 + logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16) + labels = 
torch.randint(0, V, (N,), device=DEVICE) + + ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0) + out = fused_log_softmax_gather(logits, labels) + + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + +# --------------------------------------------------------------------------- +# 2. fused_reinforce_loss +# --------------------------------------------------------------------------- +class TestFusedReinforceLoss: + def _reference(self, logps, advantages, mask): + """PyTorch reference implementation.""" + loss_per_token = -logps * advantages + return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1) + + def test_basic_correctness(self): + N = 1024 + logps = torch.randn(N, device=DEVICE, dtype=torch.float32) + advs = torch.randn(N, device=DEVICE, dtype=torch.float32) + mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool) + + ref = self._reference(logps, advs, mask) + out = fused_reinforce_loss(logps, advs, mask) + + torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + """Test with (B, S) shaped inputs.""" + B, S = 4, 256 + logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32) + advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32) + mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool) + + ref = self._reference(logps, advs, mask) + out = fused_reinforce_loss(logps, advs, mask) + + torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4) + + def test_all_masked(self): + """All-zero mask should return 0.""" + N = 512 + logps = torch.randn(N, device=DEVICE, dtype=torch.float32) + advs = torch.randn(N, device=DEVICE, dtype=torch.float32) + mask = torch.zeros(N, device=DEVICE, dtype=torch.bool) + + out = fused_reinforce_loss(logps, advs, mask) + assert out.item() == 0.0 + + def test_all_unmasked(self): + N = 512 + logps = torch.randn(N, device=DEVICE, dtype=torch.float32) + advs = torch.randn(N, device=DEVICE, dtype=torch.float32) + mask = torch.ones(N, device=DEVICE, dtype=torch.bool) + + ref = self._reference(logps, advs, mask) + out = fused_reinforce_loss(logps, advs, mask) + + torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4) + + def test_large(self): + """Test with realistic size (4 * 3000 tokens).""" + N = 12000 + logps = torch.randn(N, device=DEVICE, dtype=torch.float32) + advs = torch.randn(N, device=DEVICE, dtype=torch.float32) + mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool) + + ref = self._reference(logps, advs, mask) + out = fused_reinforce_loss(logps, advs, mask) + + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + +# --------------------------------------------------------------------------- +# 3. 
fused_cosine_similarity +# --------------------------------------------------------------------------- +class TestFusedCosineSimilarity: + def test_basic_correctness(self): + N, D = 64, 256 + a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16) + b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16) + + ref = F.cosine_similarity(a.float(), b.float(), dim=-1) + out = fused_cosine_similarity(a, b) + + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + def test_batched(self): + """Test with (B, N, NB, D) shaped input.""" + B, N, NB, D = 2, 4, 16, 512 + a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16) + b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16) + + ref = F.cosine_similarity(a.float(), b.float(), dim=-1) + out = fused_cosine_similarity(a, b) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + def test_identical_vectors(self): + """Identical vectors should give similarity = 1.""" + N, D = 16, 128 + a = torch.randn(N, D, device=DEVICE, dtype=torch.float32) + + out = fused_cosine_similarity(a, a) + torch.testing.assert_close( + out, + torch.ones(N, device=DEVICE, dtype=torch.float32), + atol=1e-5, + rtol=1e-5, + ) + + def test_orthogonal_vectors(self): + """Orthogonal vectors should give similarity = 0.""" + D = 128 + a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32) + b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32) + a[0, 0] = 1.0 + b[0, 1] = 1.0 + + out = fused_cosine_similarity(a, b) + assert abs(out.item()) < 1e-5 + + def test_opposite_vectors(self): + """Opposite vectors should give similarity = -1.""" + N, D = 8, 64 + a = torch.randn(N, D, device=DEVICE, dtype=torch.float32) + out = fused_cosine_similarity(a, -a) + torch.testing.assert_close( + out, + -torch.ones(N, device=DEVICE, dtype=torch.float32), + atol=1e-5, + rtol=1e-5, + ) + + def test_large_dimension(self): + """Test with large feature dimension (multi-layer concatenated features).""" + N, D = 32, 4608 # 3 layers * 1536 hidden + a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16) + b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16) + + ref = F.cosine_similarity(a.float(), b.float(), dim=-1) + out = fused_cosine_similarity(a, b) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + +# --------------------------------------------------------------------------- +# 4. 
fused_diversity_penalty +# --------------------------------------------------------------------------- +class TestFusedDiversityPenalty: + def _reference(self, embeddings): + """PyTorch reference: bmm + mask diagonal + mean.""" + B, N, D = embeddings.shape + sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2)) + eye = torch.eye(N, device=embeddings.device, dtype=torch.bool) + sims = sims.masked_fill(eye.unsqueeze(0), 0.0) + return sims.sum(dim=-1) / (N - 1) + + def test_basic_correctness(self): + B, N, D = 4, 4, 256 + emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16) + + ref = self._reference(emb) + out = fused_diversity_penalty(emb) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + def test_two_samples(self): + """With n=2, diversity = dot(a, b) for each.""" + B, D = 3, 128 + emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32) + + ref = self._reference(emb) + out = fused_diversity_penalty(emb) + + torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4) + + def test_identical_samples(self): + """All identical samples should give max diversity.""" + B, N, D = 2, 4, 64 + vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32) + emb = vec.expand(B, N, D).contiguous() + + out = fused_diversity_penalty(emb) + # Should be ||vec||^2 for each (self-excluded mean of identical dot products) + expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N) + torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4) + + def test_large(self): + """Test with realistic EBFT dimensions.""" + B, N, D = 1, 4, 4608 # multi-layer features + emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16) + + ref = self._reference(emb) + out = fused_diversity_penalty(emb) + + torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2) + + def test_single_sample_returns_zeros(self): + """N=1 should return zeros (no pairs), not garbage from uninitialized memory.""" + B, D = 3, 128 + emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32) + out = fused_diversity_penalty(emb) + assert (out == 0).all(), "N=1 diversity should be exactly zero" diff --git a/tests/test_ebft_strided_structured.py b/tests/test_ebft_strided_structured.py new file mode 100644 index 000000000..e4ad946a0 --- /dev/null +++ b/tests/test_ebft_strided_structured.py @@ -0,0 +1,363 @@ +"""Tests for the EBFT strided structured dataset transform and data loading.""" + +import pytest +from datasets import Dataset +from tokenizers import Tokenizer, models, pre_tokenizers +from transformers import PreTrainedTokenizerFast + +from axolotl.prompt_strategies.ebft import load as load_ebft +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def tokenizer(): + """Create a simple word-level tokenizer — no network access needed.""" + # Build a tiny vocab covering common test words + vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3} + words = ( + "what is 2 + the answer 4 hello world goodbye bye hi short prompt " + "x write code print test some string metadata noise ok good python " + "sampling abc 123 def solve return this that" + ).split() + for w in words: + if w not in vocab: + vocab[w] = len(vocab) + + backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]")) + backend.pre_tokenizer = pre_tokenizers.Whitespace() + + tok = PreTrainedTokenizerFast( + tokenizer_object=backend, + bos_token="[BOS]", + eos_token="[EOS]", + pad_token="[PAD]", + unk_token="[UNK]", + ) + return tok + + +@pytest.fixture +def cfg(): + return 
DictDefault({"sequence_len": 64}) + + +@pytest.fixture +def transform_fn_and_kwargs(cfg): + result = load_ebft("ebft_strided_structured.transform", cfg) + assert result is not None, "Failed to load ebft_strided_structured transform" + transform_fn, map_kwargs = result + return transform_fn, map_kwargs + + +class TestEBFTStridedStructuredTransform: + """Tests for the dataset transform function itself.""" + + def test_transform_loads(self, transform_fn_and_kwargs): + transform_fn, map_kwargs = transform_fn_and_kwargs + assert callable(transform_fn) + assert "remove_columns" in map_kwargs + + def test_remove_columns_sentinel(self, transform_fn_and_kwargs): + """Transform should signal removal of all original columns.""" + _, map_kwargs = transform_fn_and_kwargs + assert map_kwargs["remove_columns"] == "__all__" + + def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer): + """Prompt tokens get labels=-100, completion tokens get real labels.""" + transform_fn, _ = transform_fn_and_kwargs + example = {"input": "what is 2 + 2", "output": "the answer is 4"} + result = transform_fn(example, tokenizer=tokenizer) + + assert "input_ids" in result + assert "labels" in result + assert "attention_mask" in result + assert "prompt_length" in result + + prompt_length = result["prompt_length"] + labels = result["labels"] + seq_len = len(result["input_ids"]) + + assert seq_len == 64, "Should be padded to sequence_len" + assert len(labels) == seq_len + assert prompt_length > 0 + + # Prompt tokens should be masked + for i in range(prompt_length): + assert labels[i] == -100, f"Prompt token at {i} should be -100" + + # At least one completion token should have a real label + completion_labels = [lab for lab in labels[prompt_length:] if lab != -100] + assert len(completion_labels) > 0, "Should have non-masked completion tokens" + + def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer): + """prompt_length should be the exact boundary between -100 and real labels.""" + transform_fn, _ = transform_fn_and_kwargs + example = {"input": "hello world", "output": "goodbye world"} + result = transform_fn(example, tokenizer=tokenizer) + + prompt_length = result["prompt_length"] + labels = result["labels"] + + assert labels[prompt_length - 1] == -100, "Last prompt token should be masked" + assert labels[prompt_length] != -100, ( + "First completion token should not be masked" + ) + + def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer): + """Padding tokens should have labels=-100 and attention_mask=0.""" + transform_fn, _ = transform_fn_and_kwargs + example = {"input": "hi", "output": "bye"} + result = transform_fn(example, tokenizer=tokenizer) + + attention_mask = result["attention_mask"] + labels = result["labels"] + + real_len = sum(attention_mask) + assert real_len < 64, "Short example should have padding" + + for i in range(real_len, 64): + assert attention_mask[i] == 0, ( + f"Pad position {i} should have attention_mask=0" + ) + assert labels[i] == -100, f"Pad position {i} should have labels=-100" + + def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer): + """Long completions should be truncated to fit sequence_len.""" + transform_fn, _ = transform_fn_and_kwargs + example = { + "input": "short prompt", + "output": "x " * 500, + } + result = transform_fn(example, tokenizer=tokenizer) + assert len(result["input_ids"]) == 64 + + def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer): + """Transform should 
handle different field name conventions.""" + transform_fn, _ = transform_fn_and_kwargs + + result = transform_fn( + {"prompt": "what", "completion": "this"}, tokenizer=tokenizer + ) + assert result["prompt_length"] > 0 + + result = transform_fn( + {"question": "what", "answer": "this"}, tokenizer=tokenizer + ) + assert result["prompt_length"] > 0 + + def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs): + """Without tokenizer, should return a prompt string.""" + transform_fn, _ = transform_fn_and_kwargs + result = transform_fn({"input": "hello", "output": "world"}) + assert "prompt" in result + assert result["prompt"] == "hello" + + +class TestEBFTColumnRemoval: + """Tests for the __all__ column removal logic in the RL data path.""" + + def _filter_remove_columns(self, map_kwargs, dataset): + """Reproduce the filtering logic from rl.py _load_split.""" + if "remove_columns" in map_kwargs: + ds_columns = dataset.column_names + if map_kwargs["remove_columns"] == "__all__": + map_kwargs["remove_columns"] = list(ds_columns) + else: + map_kwargs["remove_columns"] = [ + c for c in map_kwargs["remove_columns"] if c in ds_columns + ] + return map_kwargs + + def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer): + """After mapping, only tokenized columns should remain.""" + transform_fn, map_kwargs = transform_fn_and_kwargs + map_kwargs = dict(map_kwargs) # copy + + ds = Dataset.from_list( + [ + {"input": "what is 2 + 2", "output": "4", "extra_field": "noise"}, + ] + ) + + map_kwargs = self._filter_remove_columns(map_kwargs, ds) + assert "input" in map_kwargs["remove_columns"] + assert "output" in map_kwargs["remove_columns"] + assert "extra_field" in map_kwargs["remove_columns"] + + from functools import partial + + mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs) + assert "input_ids" in mapped.column_names + assert "labels" in mapped.column_names + assert "prompt_length" in mapped.column_names + assert "input" not in mapped.column_names + assert "output" not in mapped.column_names + assert "extra_field" not in mapped.column_names + + def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer): + """Datasets with many extra metadata columns should all be cleaned up.""" + transform_fn, map_kwargs = transform_fn_and_kwargs + map_kwargs = dict(map_kwargs) + + ds = Dataset.from_list( + [ + { + "input": "write hello world", + "output": "print hello", + "id": "abc 123", + "domain": "python", + "generation_algorithm": "sampling", + "llm_judgement": "good", + "unit_tests": "test", + "tests_execution_status": "ok", + "average_test_score": 0.95, + }, + ] + ) + + map_kwargs = self._filter_remove_columns(map_kwargs, ds) + assert len(map_kwargs["remove_columns"]) == 9 + + from functools import partial + + mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs) + + expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"} + assert set(mapped.column_names) == expected_columns + + def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer): + """No string-typed columns should remain (would crash the DataLoader).""" + transform_fn, map_kwargs = transform_fn_and_kwargs + map_kwargs = dict(map_kwargs) + + ds = Dataset.from_list( + [ + {"input": "test", "output": "test", "notes": "some string metadata"}, + ] + ) + + map_kwargs = self._filter_remove_columns(map_kwargs, ds) + + from functools import partial + + mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), 
**map_kwargs) + + for col in mapped.column_names: + feat = mapped.features[col] + assert str(feat) != "string", ( + f"Column '{col}' is still a string — would crash DataLoader" + ) + + def test_filter_preserves_explicit_list(self): + """When remove_columns is an explicit list, only existing columns are kept.""" + ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}]) + map_kwargs = {"remove_columns": ["a", "b", "missing_col"]} + + ds_columns = ds.column_names + map_kwargs["remove_columns"] = [ + c for c in map_kwargs["remove_columns"] if c in ds_columns + ] + + assert map_kwargs["remove_columns"] == ["a", "b"] + assert "missing_col" not in map_kwargs["remove_columns"] + + +class TestMultiTurnSeparators: + """Verify multi-turn transforms and trainer-side GT reconstruction.""" + + def test_multiturn_transform_splits_turns(self): + """Transform should store first turn as GT and remaining turns separately.""" + from axolotl.prompt_strategies.ebft import load as load_ebft + from axolotl.utils.dict import DictDefault + + cfg = DictDefault({"sequence_len": 512}) + fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg) + out = fn( + { + "messages": [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + } + ) + # ground_truth is only the first assistant turn + assert out["ground_truth"] == "A1" + # remaining_turns carries the rest for trainer-side reconstruction + assert out["remaining_turns"] == [ + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + + def test_multiturn_gt_reconstruction_via_chat_template(self): + """Trainer-side GT reconstruction should insert role markers between turns. + + This tests the logic from trainer.py:284-299 that reconstructs multi-turn + GT using apply_chat_template, ensuring assistant turns are separated by + role markers rather than concatenated as raw text. 
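+
+        The Qwen2-0.5B-Instruct tokenizer is used only for its chat template;
+        any template with explicit role markers would exercise the same logic.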
+ """ + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True + ) + + # Simulate the transform output + prompt_msgs = [{"role": "user", "content": "Q1"}] + gt = "A1" + remaining_turns = [ + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + + # --- Reproduce the trainer-side reconstruction (trainer.py:284-299) --- + prompt_text = tokenizer.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}] + gt_conv.extend(remaining_turns) + full_gt_text = tokenizer.apply_chat_template( + gt_conv, tokenize=False, add_generation_prompt=False + ) + + # The full GT text should contain both assistant turns with role markers + assert "A1" in full_gt_text + assert "A2" in full_gt_text + # Raw concatenation "A1A2" should NOT appear — role markers separate them + assert "A1A2" not in full_gt_text, ( + "GT reconstruction should have role markers between turns, not raw concatenation" + ) + # The user turn Q2 should appear between A1 and A2 + a1_pos = full_gt_text.index("A1") + a2_pos = full_gt_text.index("A2") + q2_pos = full_gt_text.index("Q2") + assert a1_pos < q2_pos < a2_pos, ( + "Turn order should be A1 -> Q2 -> A2 in rendered GT" + ) + # The GT should start with the prompt + assert full_gt_text.startswith(prompt_text), ( + "Full GT should start with the rendered prompt" + ) + + def test_multiturn_gt_reconstruction_fallback_single_turn(self): + """Single-turn prompts in a multi-turn dataset should use raw concatenation.""" + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True + ) + + prompt_msgs = [{"role": "user", "content": "Q1"}] + gt = "A1" + # remaining_turns would be [] for single-turn prompts + + prompt_text = tokenizer.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + + # With empty remaining_turns, trainer falls through to raw concat + # (trainer.py:302: gt_texts.append(prompt_text + gt)) + gt_text = prompt_text + gt + assert gt_text.endswith("A1") + assert prompt_text in gt_text diff --git a/tests/test_http_weight_sync.py b/tests/test_http_weight_sync.py new file mode 100644 index 000000000..97c5ece71 --- /dev/null +++ b/tests/test_http_weight_sync.py @@ -0,0 +1,158 @@ +"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32). + +Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle +the three-stage weight transfer: trainer → serve endpoint → vLLM worker. 
+""" + +import pytest +import torch + +from axolotl.utils.weight_serde import ( + decode_from_http, + decode_from_ipc, + encode_for_http, + encode_for_ipc, +) + +# --------------------------------------------------------------------------- +# Stage 1: trainer → serve endpoint (HTTP with base64) +# --------------------------------------------------------------------------- + + +class TestHttpEncodeRoundTrip: + """Test encode_for_http / decode_from_http.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_round_trip_dtype(self, dtype): + original = torch.randn(32, 64, dtype=dtype) + entry = encode_for_http("layer.weight", original) + name, decoded = decode_from_http(entry) + + assert name == "layer.weight" + assert decoded.dtype == dtype + assert decoded.shape == original.shape + if dtype == torch.bfloat16: + # bf16→fp16→bf16 loses some precision + torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2) + else: + torch.testing.assert_close(decoded, original, atol=0, rtol=0) + + def test_bfloat16_wire_format_is_fp16(self): + """bf16 tensors should be sent as fp16 on the wire.""" + import base64 + + original = torch.randn(8, 16, dtype=torch.bfloat16) + entry = encode_for_http("w", original) + raw = base64.b64decode(entry["data"]) + # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes + assert len(raw) == 8 * 16 * 2 + # dtype field should preserve original dtype for reconstruction + assert entry["dtype"] == "torch.bfloat16" + + def test_multidimensional_shapes(self): + for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]: + original = torch.randn(*shape, dtype=torch.bfloat16) + entry = encode_for_http("w", original) + _, decoded = decode_from_http(entry) + assert decoded.shape == original.shape + assert decoded.dtype == torch.bfloat16 + + +# --------------------------------------------------------------------------- +# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes) +# --------------------------------------------------------------------------- + + +class TestIpcEncodeRoundTrip: + """Test encode_for_ipc / decode_from_ipc.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_round_trip_dtype(self, dtype): + original = torch.randn(32, 64, dtype=dtype) + entry = encode_for_ipc("layer.weight", original) + name, decoded = decode_from_ipc(entry) + + assert name == "layer.weight" + assert decoded.dtype == dtype + assert decoded.shape == original.shape + if dtype == torch.bfloat16: + torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2) + else: + torch.testing.assert_close(decoded, original, atol=0, rtol=0) + + def test_bfloat16_ipc_wire_is_fp16(self): + """bf16 tensors should be serialized as fp16 bytes in IPC.""" + original = torch.randn(4, 8, dtype=torch.bfloat16) + entry = encode_for_ipc("w", original) + assert entry["dtype"] == "float16" + assert entry["target_dtype"] == "bfloat16" + assert len(entry["data"]) == 4 * 8 * 2 # fp16 bytes + + def test_fp32_has_no_target_dtype_mismatch(self): + original = torch.randn(4, 8, dtype=torch.float32) + entry = encode_for_ipc("w", original) + assert entry["dtype"] == "float32" + assert entry["target_dtype"] == "float32" + + def test_worker_handles_missing_target_dtype(self): + """Backward compat: older serve code may not send target_dtype.""" + entry = { + "name": "w", + "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(), + "dtype": "float32", + "shape": [4, 8], + # no target_dtype key + } + name, decoded = 
decode_from_ipc(entry) + assert decoded.dtype == torch.float32 + assert decoded.shape == (4, 8) + + +# --------------------------------------------------------------------------- +# Full pipeline: trainer → serve → worker +# --------------------------------------------------------------------------- + + +class TestFullPipelineRoundTrip: + """End-to-end: trainer → serve → worker.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_three_stage_round_trip(self, dtype): + """Tensor survives trainer→serve→worker with correct dtype and values.""" + original = torch.randn(16, 32, dtype=dtype) + + # Stage 1: trainer encodes for HTTP + http_entry = encode_for_http("model.layers.0.weight", original) + + # Stage 2: serve decodes HTTP, re-encodes for IPC + name, at_serve = decode_from_http(http_entry) + ipc_entry = encode_for_ipc(name, at_serve) + + # Stage 3: worker decodes IPC + _, at_worker = decode_from_ipc(ipc_entry) + + assert at_worker.dtype == dtype + assert at_worker.shape == original.shape + if dtype == torch.bfloat16: + # Two bf16→fp16→bf16 hops compound precision loss slightly + torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2) + else: + torch.testing.assert_close(at_worker, original, atol=0, rtol=0) + + def test_bfloat16_precision_loss_is_bounded(self): + """bf16→fp16→bf16 round-trip error should be small.""" + original = torch.randn(256, 256, dtype=torch.bfloat16) + http_entry = encode_for_http("w", original) + _, at_serve = decode_from_http(http_entry) + ipc_entry = encode_for_ipc("w", at_serve) + _, at_worker = decode_from_ipc(ipc_entry) + + max_err = (at_worker.float() - original.float()).abs().max().item() + # bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded + assert max_err < 0.05, f"Max error {max_err} exceeds bound" + + def test_bfloat16_numpy_would_crash_without_fix(self): + """Verify that calling .numpy() on bf16 raises, confirming the fix is needed.""" + t = torch.randn(4, 4, dtype=torch.bfloat16) + with pytest.raises((RuntimeError, TypeError)): + t.numpy()
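
For context, here is a minimal trainer-side sketch of the HTTP weight-sync path exercised by these tests. The endpoint path and payload layout come from `/http_update_weights/` and `encode_for_http` in this patch; the host/port and the parameter names/shapes are assumptions for illustration.

```python
import requests
import torch

from axolotl.utils.weight_serde import encode_for_http

# Hypothetical parameters to push; names and shapes are illustrative only.
params = {
    "model.layers.0.mlp.gate_proj.weight": torch.randn(16, 32, dtype=torch.bfloat16),
    "model.layers.0.mlp.up_proj.weight": torch.randn(16, 32, dtype=torch.bfloat16),
}

# encode_for_http base64-encodes raw bytes (bf16 travels as fp16), so the
# request body is plain JSON and no NCCL communicator is needed on either side.
payload = {"params": [encode_for_http(name, w) for name, w in params.items()]}

# Assumed serve address; the real port comes from the vLLM serve configuration.
resp = requests.post("http://localhost:8000/http_update_weights/", json=payload)
resp.raise_for_status()
print(resp.json())  # e.g. {"message": "HTTP weight update for 2 params"}
```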