EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip * fixes * more fixeS * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address corereview feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
This commit is contained in:
@@ -38,18 +38,14 @@ def do_vllm_serve(
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
|
||||
# Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
|
||||
# We default to axolotl's serve module instead of TRL's because TRL's sends
|
||||
# truncate_prompt_tokens which is unsupported in vLLM 0.17+.
|
||||
serve_module = cli_args.get("serve_module") or getattr(
|
||||
cfg.vllm, "serve_module", None
|
||||
)
|
||||
if (
|
||||
serve_module is None
|
||||
and getattr(cfg, "trl", None)
|
||||
and getattr(cfg.trl, "vllm_lora_sync", False)
|
||||
):
|
||||
serve_module = "axolotl.scripts.vllm_serve_lora"
|
||||
if serve_module is None:
|
||||
serve_module = "trl.scripts.vllm_serve"
|
||||
serve_module = "axolotl.scripts.vllm_serve_lora"
|
||||
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
@@ -79,6 +75,12 @@ def do_vllm_serve(
|
||||
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
|
||||
)
|
||||
|
||||
cli_enforce_eager = cli_args.get("enforce_eager")
|
||||
cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
|
||||
raw_enforce_eager = (
|
||||
cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
|
||||
)
|
||||
enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
|
||||
base_kwargs = dict(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
@@ -89,6 +91,7 @@ def do_vllm_serve(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
# Use LoRAScriptArguments when serving with native LoRA support
|
||||
@@ -98,6 +101,10 @@ def do_vllm_serve(
|
||||
lora_kwargs = {}
|
||||
if hasattr(cfg, "lora_r") and cfg.lora_r:
|
||||
lora_kwargs["max_lora_rank"] = cfg.lora_r
|
||||
# Disable native LoRA in vLLM if not using vllm_lora_sync
|
||||
# (merged weight sync via batch_update doesn't need vLLM LoRA mode)
|
||||
if not getattr(cfg.trl, "vllm_lora_sync", False):
|
||||
lora_kwargs["enable_lora"] = False
|
||||
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
|
||||
else:
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
|
||||
@@ -118,7 +118,7 @@ def load_preference_datasets(
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
total_num_steps: int | None = None
|
||||
if cfg.rl is not RLType.GRPO:
|
||||
if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
@@ -78,6 +78,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls = AxolotlKTOTrainer
|
||||
elif self.cfg.rl is RLType.SIMPO:
|
||||
trainer_cls = AxolotlCPOTrainer
|
||||
elif self.cfg.rl is RLType.EBFT:
|
||||
from axolotl.core.trainers.ebft import EBFTStrategy
|
||||
|
||||
trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
|
||||
trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
@@ -179,6 +184,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl is RLType.EBFT:
|
||||
from axolotl.core.trainers.ebft import EBFTStrategy
|
||||
|
||||
training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
|
||||
training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
@@ -211,7 +223,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and self.peft_config
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
|
||||
):
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .ebft.strided import AxolotlStridedEBFTTrainer
|
||||
from .ebft.trainer import AxolotlEBFTTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
|
||||
213
src/axolotl/core/trainers/ebft/__init__.py
Normal file
213
src/axolotl/core/trainers/ebft/__init__.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""EBFT (Energy-Based Fine-Tuning) Strategy for training
|
||||
|
||||
Two modes:
|
||||
- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
|
||||
- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from axolotl.core.trainers.ebft.args import (
|
||||
AxolotlAsyncEBFTConfig,
|
||||
AxolotlEBFTConfig,
|
||||
AxolotlStridedEBFTConfig,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def _get_ebft_mode(cfg: DictDefault) -> str:
|
||||
"""Determine EBFT mode from config."""
|
||||
if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
|
||||
return cfg.ebft.mode
|
||||
return "structured"
|
||||
|
||||
|
||||
class EBFTStrategy:
|
||||
"""Strategy for EBFT training — dispatches between structured and strided modes."""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls, cfg: DictDefault | None = None):
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer
|
||||
|
||||
return AxolotlStridedEBFTTrainer
|
||||
|
||||
# Structured mode: async or sync
|
||||
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
|
||||
if use_async:
|
||||
from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
|
||||
|
||||
return AxolotlAsyncEBFTTrainer
|
||||
from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer
|
||||
|
||||
return AxolotlEBFTTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls, cfg: DictDefault | None = None):
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
return AxolotlStridedEBFTConfig
|
||||
|
||||
# Structured mode: async or sync config
|
||||
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
|
||||
if use_async:
|
||||
return AxolotlAsyncEBFTConfig
|
||||
return AxolotlEBFTConfig
|
||||
|
||||
@classmethod
|
||||
def is_strided(cls, cfg: DictDefault) -> bool:
|
||||
return _get_ebft_mode(cfg) == "strided"
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
"""Map axolotl YAML config fields to training args kwargs."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
mode = _get_ebft_mode(cfg)
|
||||
|
||||
# Common EBFT fields
|
||||
ebft = cfg.ebft
|
||||
if ebft:
|
||||
if ebft.feature_layers is not None:
|
||||
kwargs["ebft_feature_layers"] = ebft.feature_layers
|
||||
if ebft.embed_method is not None:
|
||||
kwargs["ebft_embed_method"] = ebft.embed_method
|
||||
if ebft.use_whitening is not None:
|
||||
kwargs["ebft_use_whitening"] = ebft.use_whitening
|
||||
if ebft.alignment_coef is not None:
|
||||
kwargs["ebft_alignment_coef"] = ebft.alignment_coef
|
||||
if ebft.diversity_coef is not None:
|
||||
kwargs["ebft_diversity_coef"] = ebft.diversity_coef
|
||||
if ebft.ce_coef is not None:
|
||||
kwargs["ebft_ce_coef"] = ebft.ce_coef
|
||||
if getattr(ebft, "adaptive_max_tokens", None) is not None:
|
||||
kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
|
||||
if getattr(ebft, "gt_length_multiplier", None) is not None:
|
||||
kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier
|
||||
|
||||
if mode == "strided":
|
||||
# Strided-specific fields
|
||||
if ebft:
|
||||
if ebft.stride is not None:
|
||||
kwargs["ebft_stride"] = ebft.stride
|
||||
if ebft.context_length is not None:
|
||||
kwargs["ebft_context_length"] = ebft.context_length
|
||||
if ebft.generate_max_len is not None:
|
||||
kwargs["ebft_generate_max_len"] = ebft.generate_max_len
|
||||
if ebft.n_samples_per_prompt is not None:
|
||||
kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
|
||||
if ebft.temperature is not None:
|
||||
kwargs["ebft_temperature"] = ebft.temperature
|
||||
if ebft.top_p is not None:
|
||||
kwargs["ebft_top_p"] = ebft.top_p
|
||||
if ebft.rl_coef is not None:
|
||||
kwargs["ebft_rl_coef"] = ebft.rl_coef
|
||||
if ebft.advantage_estimator is not None:
|
||||
kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
|
||||
if ebft.min_completion_prefix is not None:
|
||||
kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
|
||||
else:
|
||||
# Structured mode: map TRL config fields
|
||||
trl = cfg.trl
|
||||
if trl:
|
||||
if trl.use_vllm:
|
||||
kwargs["use_vllm"] = trl.use_vllm
|
||||
if trl.vllm_mode:
|
||||
kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
|
||||
vllm_cfg = cfg.vllm
|
||||
if vllm_cfg:
|
||||
kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
kwargs["vllm_tensor_parallel_size"] = (
|
||||
vllm_cfg.tensor_parallel_size
|
||||
)
|
||||
kwargs["vllm_server_host"] = trl.vllm_server_host or (
|
||||
trl.vllm.host if trl.vllm else None
|
||||
)
|
||||
kwargs["vllm_server_port"] = trl.vllm_server_port or (
|
||||
trl.vllm.port if trl.vllm else None
|
||||
)
|
||||
if trl.vllm_server_timeout:
|
||||
kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
|
||||
if trl.num_generations:
|
||||
kwargs["num_generations"] = trl.num_generations
|
||||
if trl.max_completion_length is not None:
|
||||
kwargs["max_completion_length"] = trl.max_completion_length
|
||||
if trl.temperature is not None:
|
||||
kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
kwargs["min_p"] = trl.min_p
|
||||
if trl.num_iterations is not None:
|
||||
kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
kwargs["epsilon"] = trl.epsilon
|
||||
if trl.epsilon_high is not None:
|
||||
kwargs["epsilon_high"] = trl.epsilon_high
|
||||
if trl.scale_rewards is not None:
|
||||
kwargs["scale_rewards"] = trl.scale_rewards
|
||||
if trl.loss_type is not None:
|
||||
kwargs["loss_type"] = trl.loss_type
|
||||
if trl.mask_truncated_completions is not None:
|
||||
kwargs["mask_truncated_completions"] = (
|
||||
trl.mask_truncated_completions
|
||||
)
|
||||
if trl.log_completions is not None:
|
||||
kwargs["log_completions"] = trl.log_completions
|
||||
if trl.num_completions_to_print is not None:
|
||||
kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
if trl.sync_ref_model:
|
||||
kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||
if trl.repetition_penalty is not None:
|
||||
kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
if trl.generation_kwargs is not None:
|
||||
kwargs["generation_kwargs"] = trl.generation_kwargs
|
||||
if trl.chat_template_kwargs is not None:
|
||||
kwargs["chat_template_kwargs"] = trl.chat_template_kwargs
|
||||
|
||||
# Async prefetch fields (only pass when enabled — sync config doesn't have these)
|
||||
if getattr(trl, "async_prefetch", False):
|
||||
kwargs["async_prefetch"] = trl.async_prefetch
|
||||
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||
kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||
if getattr(trl, "vllm_lora_sync", False):
|
||||
kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
|
||||
mode = _get_ebft_mode(cfg) if cfg else "structured"
|
||||
if mode == "strided":
|
||||
return [
|
||||
"dataset_num_proc",
|
||||
"max_length",
|
||||
"max_prompt_length",
|
||||
"include_tokens_per_second",
|
||||
"beta",
|
||||
]
|
||||
return [
|
||||
"dataset_num_proc",
|
||||
"max_length",
|
||||
"include_tokens_per_second",
|
||||
"max_prompt_length",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs):
|
||||
return None
|
||||
133
src/axolotl/core/trainers/ebft/args.py
Normal file
133
src/axolotl/core/trainers/ebft/args.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
EBFT-specific training arguments.
|
||||
|
||||
Two config classes:
|
||||
- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
|
||||
- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
# -- Shared EBFT fields as a mixin --
|
||||
@dataclass
|
||||
class EBFTFieldsMixin:
|
||||
"""Common fields shared between structured and strided EBFT configs."""
|
||||
|
||||
ebft_feature_layers: list[float] = field(
|
||||
default_factory=lambda: [0.25, 0.5, 0.75],
|
||||
metadata={"help": "Fractional layer depths for feature extraction"},
|
||||
)
|
||||
ebft_embed_method: str = field(
|
||||
default="last_token",
|
||||
metadata={"help": "Pooling method: 'last_token', 'mean_pooling', or 'concat'"},
|
||||
)
|
||||
ebft_use_whitening: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Apply SVD whitening to feature embeddings"},
|
||||
)
|
||||
ebft_alignment_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
|
||||
)
|
||||
ebft_diversity_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Coefficient for diversity penalty"},
|
||||
)
|
||||
ebft_ce_coef: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
|
||||
)
|
||||
ebft_adaptive_max_tokens: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
|
||||
)
|
||||
ebft_gt_length_multiplier: float = field(
|
||||
default=1.5,
|
||||
metadata={
|
||||
"help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Structured mode: extends GRPOTrainer for QA data with vLLM --
|
||||
@dataclass
|
||||
class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
|
||||
"""EBFT config for structured QA data — extends GRPOConfig."""
|
||||
|
||||
vllm_lora_sync: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Async structured mode: extends FastAsyncGRPOConfig --
|
||||
@dataclass
|
||||
class AxolotlAsyncEBFTConfig(
|
||||
EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
|
||||
):
|
||||
"""EBFT config for async structured QA data — extends FastAsyncGRPOConfig.
|
||||
|
||||
Includes all async fields: async_prefetch, vllm_lora_sync,
|
||||
skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
|
||||
"""
|
||||
|
||||
vllm_lora_sync: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- Strided mode: extends TrainingArguments for unstructured text --
|
||||
@dataclass
|
||||
class AxolotlStridedEBFTConfig(
|
||||
EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
|
||||
):
|
||||
"""EBFT config for unstructured text with strided block-parallel generation."""
|
||||
|
||||
ebft_stride: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Stride between anchor points (in tokens)"},
|
||||
)
|
||||
ebft_context_length: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Context window size for each block"},
|
||||
)
|
||||
ebft_generate_max_len: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of tokens to generate per block"},
|
||||
)
|
||||
ebft_n_samples_per_prompt: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of independent rollouts per document"},
|
||||
)
|
||||
ebft_temperature: float = field(
|
||||
default=0.6,
|
||||
metadata={"help": "Sampling temperature for strided generation"},
|
||||
)
|
||||
ebft_top_p: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Top-p nucleus sampling threshold"},
|
||||
)
|
||||
ebft_rl_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "RL policy gradient loss coefficient"},
|
||||
)
|
||||
ebft_advantage_estimator: str = field(
|
||||
default="rloo",
|
||||
metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
|
||||
)
|
||||
ebft_min_completion_prefix: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Minimum tokens into completion before placing anchors"},
|
||||
)
|
||||
308
src/axolotl/core/trainers/ebft/kernels.py
Normal file
308
src/axolotl/core/trainers/ebft/kernels.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Fused Triton kernels for strided EBFT.
|
||||
|
||||
These kernels eliminate intermediate tensor materializations that dominate
|
||||
the elementwise/fill category (~40% of CUDA time in profiling).
|
||||
|
||||
Kernels:
|
||||
1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
|
||||
2. fused_masked_reinforce_loss: -logp * advantage * mask, reduced to scalar
|
||||
3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Fused log_softmax + gather (selective log softmax)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
|
||||
# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
|
||||
# This avoids materializing the full (B, S, V) log_softmax output.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_log_softmax_gather_kernel(
|
||||
logits_ptr, # (B*S, V) row-major
|
||||
labels_ptr, # (B*S,) int64
|
||||
output_ptr, # (B*S,) float32
|
||||
V: tl.constexpr, # vocab size
|
||||
BLOCK_V: tl.constexpr, # tile width over vocab
|
||||
):
|
||||
"""Compute log_softmax(logits)[label] for each row without materializing full output."""
|
||||
row = tl.program_id(0)
|
||||
|
||||
logits_row_ptr = logits_ptr + row * V
|
||||
label = tl.load(labels_ptr + row)
|
||||
|
||||
# Pass 1: find max for numerical stability
|
||||
max_val = -float("inf")
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
v_offsets = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = v_offsets < V
|
||||
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
|
||||
max_val = tl.maximum(max_val, tl.max(vals, axis=0))
|
||||
|
||||
# Pass 2: compute sum(exp(x - max))
|
||||
sum_exp = 0.0
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
v_offsets = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = v_offsets < V
|
||||
vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
|
||||
sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)
|
||||
|
||||
log_sum_exp = tl.log(sum_exp)
|
||||
|
||||
# Gather: log_softmax[label] = logits[label] - max - log_sum_exp
|
||||
target_logit = tl.load(logits_row_ptr + label)
|
||||
result = target_logit - max_val - log_sum_exp
|
||||
|
||||
tl.store(output_ptr + row, result)
|
||||
|
||||
|
||||
def fused_log_softmax_gather(
|
||||
logits: torch.Tensor, labels: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
|
||||
|
||||
Args:
|
||||
logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
|
||||
labels: (B, S) or (B*S,) int64 tensor of target indices
|
||||
|
||||
Returns:
|
||||
(B, S) or (B*S,) float32 tensor of selected log probabilities
|
||||
"""
|
||||
orig_shape = logits.shape[:-1]
|
||||
V = logits.shape[-1]
|
||||
logits_2d = logits.reshape(-1, V).contiguous()
|
||||
labels_1d = labels.reshape(-1).contiguous()
|
||||
n_rows = logits_2d.shape[0]
|
||||
|
||||
output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
|
||||
|
||||
# Choose BLOCK_V: must be power of 2, large enough for good occupancy
|
||||
BLOCK_V = min(triton.next_power_of_2(V), 65536)
|
||||
|
||||
_fused_log_softmax_gather_kernel[(n_rows,)](
|
||||
logits_2d,
|
||||
labels_1d,
|
||||
output,
|
||||
V=V,
|
||||
BLOCK_V=BLOCK_V,
|
||||
)
|
||||
|
||||
return output.view(orig_shape)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Fused masked REINFORCE loss reduction
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: (-logp * adv * mask).sum() / mask.sum()
|
||||
# We do the masked multiply-accumulate in one kernel, returning (sum, count).
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_reinforce_loss_kernel(
|
||||
logps_ptr, # (N,) float32 per-token log probs
|
||||
advs_ptr, # (N,) float32 advantages
|
||||
mask_ptr, # (N,) bool action mask
|
||||
partial_sum_ptr, # (n_blocks,) partial sums
|
||||
partial_cnt_ptr, # (n_blocks,) partial counts
|
||||
N: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
block_id = tl.program_id(0)
|
||||
offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
valid = offsets < N
|
||||
|
||||
logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
|
||||
advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
|
||||
m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)
|
||||
|
||||
# -logp * advantage * mask
|
||||
loss = -logps * advs * m
|
||||
block_sum = tl.sum(loss, axis=0)
|
||||
block_cnt = tl.sum(m, axis=0)
|
||||
|
||||
tl.store(partial_sum_ptr + block_id, block_sum)
|
||||
tl.store(partial_cnt_ptr + block_id, block_cnt)
|
||||
|
||||
|
||||
def fused_reinforce_loss(
|
||||
per_token_logps: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
action_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
|
||||
|
||||
All inputs should be flat or will be flattened. Returns scalar loss tensor.
|
||||
"""
|
||||
logps_flat = per_token_logps.reshape(-1).contiguous()
|
||||
advs_flat = advantages.reshape(-1).contiguous()
|
||||
mask_flat = action_mask.reshape(-1).contiguous()
|
||||
N = logps_flat.shape[0]
|
||||
|
||||
BLOCK_N = 1024
|
||||
n_blocks = triton.cdiv(N, BLOCK_N)
|
||||
|
||||
partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
|
||||
partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
|
||||
|
||||
_fused_reinforce_loss_kernel[(n_blocks,)](
|
||||
logps_flat,
|
||||
advs_flat,
|
||||
mask_flat,
|
||||
partial_sum,
|
||||
partial_cnt,
|
||||
N=N,
|
||||
BLOCK_N=BLOCK_N,
|
||||
)
|
||||
|
||||
total_sum = partial_sum.sum()
|
||||
total_cnt = partial_cnt.sum().clamp(min=1)
|
||||
return total_sum / total_cnt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Fused cosine similarity (batched, for EBFT rewards)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
|
||||
# we fuse the dot product, norm computation, and division into one kernel.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_cosine_sim_kernel(
|
||||
a_ptr, # (N, D) first set of vectors
|
||||
b_ptr, # (N, D) second set of vectors
|
||||
out_ptr, # (N,) cosine similarities
|
||||
D: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
):
|
||||
row = tl.program_id(0)
|
||||
a_row_ptr = a_ptr + row * D
|
||||
b_row_ptr = b_ptr + row * D
|
||||
|
||||
dot = 0.0
|
||||
norm_a = 0.0
|
||||
norm_b = 0.0
|
||||
|
||||
for d_start in range(0, D, BLOCK_D):
|
||||
d_offsets = d_start + tl.arange(0, BLOCK_D)
|
||||
mask = d_offsets < D
|
||||
a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
|
||||
b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
dot += tl.sum(a_vals * b_vals, axis=0)
|
||||
norm_a += tl.sum(a_vals * a_vals, axis=0)
|
||||
norm_b += tl.sum(b_vals * b_vals, axis=0)
|
||||
|
||||
denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
|
||||
denom = tl.where(denom > 1e-8, denom, 1e-8)
|
||||
result = dot / denom
|
||||
|
||||
tl.store(out_ptr + row, result)
|
||||
|
||||
|
||||
def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute cosine similarity along the last dimension.
|
||||
|
||||
Args:
|
||||
a, b: (..., D) tensors of the same shape
|
||||
|
||||
Returns:
|
||||
(...,) tensor of cosine similarities
|
||||
"""
|
||||
orig_shape = a.shape[:-1]
|
||||
D = a.shape[-1]
|
||||
a_2d = a.reshape(-1, D).contiguous()
|
||||
b_2d = b.reshape(-1, D).contiguous()
|
||||
N = a_2d.shape[0]
|
||||
|
||||
output = torch.empty(N, device=a.device, dtype=torch.float32)
|
||||
|
||||
BLOCK_D = min(triton.next_power_of_2(D), 4096)
|
||||
|
||||
_fused_cosine_sim_kernel[(N,)](
|
||||
a_2d,
|
||||
b_2d,
|
||||
output,
|
||||
D=D,
|
||||
BLOCK_D=BLOCK_D,
|
||||
)
|
||||
|
||||
return output.view(orig_shape)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Fused pairwise diversity penalty
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1)
|
||||
# We compute the pairwise dot products and exclusion in one kernel.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fused_diversity_kernel(
|
||||
emb_ptr, # (B, N, D) embeddings, row-major
|
||||
out_ptr, # (B, N) diversity penalties
|
||||
N: tl.constexpr, # n_samples
|
||||
D: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
):
|
||||
"""For each (b, i), compute mean dot product to all j != i."""
|
||||
b = tl.program_id(0)
|
||||
i = tl.program_id(1)
|
||||
|
||||
# Pointer to emb[b, i, :]
|
||||
emb_bi_ptr = emb_ptr + (b * N + i) * D
|
||||
|
||||
total_sim = 0.0
|
||||
for j in range(N):
|
||||
emb_bj_ptr = emb_ptr + (b * N + j) * D
|
||||
|
||||
dot = 0.0
|
||||
for d_start in range(0, D, BLOCK_D):
|
||||
d_offsets = d_start + tl.arange(0, BLOCK_D)
|
||||
d_mask = d_offsets < D
|
||||
a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
dot += tl.sum(a_vals * b_vals, axis=0)
|
||||
|
||||
# Exclude self-similarity (j == i)
|
||||
is_other = j != i
|
||||
total_sim += dot * is_other
|
||||
|
||||
result = total_sim / (N - 1)
|
||||
tl.store(out_ptr + b * N + i, result)
|
||||
|
||||
|
||||
def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute mean pairwise dot product (excluding self) per sample.
|
||||
|
||||
Args:
|
||||
embeddings: (B, N, D) tensor where N is n_samples
|
||||
|
||||
Returns:
|
||||
(B, N) tensor of diversity penalties
|
||||
"""
|
||||
B, N, D = embeddings.shape
|
||||
embeddings = embeddings.contiguous()
|
||||
output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
|
||||
if N <= 1:
|
||||
return output # diversity is 0 when there's only one sample
|
||||
|
||||
BLOCK_D = min(triton.next_power_of_2(D), 4096)
|
||||
|
||||
_fused_diversity_kernel[(B, N)](
|
||||
embeddings,
|
||||
output,
|
||||
N=N,
|
||||
D=D,
|
||||
BLOCK_D=BLOCK_D,
|
||||
)
|
||||
|
||||
return output
|
||||
264
src/axolotl/core/trainers/ebft/rewards.py
Normal file
264
src/axolotl/core/trainers/ebft/rewards.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).
|
||||
|
||||
Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
|
||||
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
|
||||
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_hidden_states(
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_indices: list[int],
|
||||
batch_size: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through model, extracting and concatenating hidden states
|
||||
at specified layer indices.
|
||||
|
||||
Args:
|
||||
model: The frozen feature network
|
||||
input_ids: (B, S) token ids
|
||||
attention_mask: (B, S) attention mask
|
||||
layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for 32-layer model)
|
||||
batch_size: If set, process in chunks to reduce peak memory
|
||||
|
||||
Returns:
|
||||
Concatenated hidden states: (B, S, num_layers * H)
|
||||
"""
|
||||
if batch_size is None:
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
# Use the inner transformer body (skips lm_head) when available.
|
||||
# This avoids the expensive hidden_dim × vocab_size matmul whose
|
||||
# output (logits) is never used — only hidden_states are needed.
|
||||
body = getattr(model, "model", None)
|
||||
if body is not None and hasattr(body, "forward"):
|
||||
forward_model = body
|
||||
else:
|
||||
forward_model = model
|
||||
|
||||
all_features = []
|
||||
for i in range(0, input_ids.shape[0], batch_size):
|
||||
chunk_ids = input_ids[i : i + batch_size]
|
||||
chunk_mask = attention_mask[i : i + batch_size]
|
||||
|
||||
outputs = forward_model(
|
||||
chunk_ids,
|
||||
attention_mask=chunk_mask,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H)
|
||||
# index 0 is the embedding layer output
|
||||
hidden_states = outputs.hidden_states
|
||||
chunk_features = []
|
||||
for idx in layer_indices:
|
||||
chunk_features.append(hidden_states[idx])
|
||||
|
||||
# Concatenate across feature dimension: (B, S, num_layers * H)
|
||||
all_features.append(torch.cat(chunk_features, dim=-1))
|
||||
|
||||
return torch.cat(all_features, dim=0)
|
||||
|
||||
|
||||
def apply_embed_method(
|
||||
hidden_states: torch.Tensor,
|
||||
method: str,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
prompt_lengths: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pool per-token hidden states into per-sequence embeddings.
|
||||
|
||||
Args:
|
||||
hidden_states: (B, S, D) concatenated hidden states
|
||||
method: One of "last_token", "mean_pooling", "completion_mean", "concat"
|
||||
attention_mask: (B, S) mask for mean pooling
|
||||
prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)
|
||||
|
||||
Returns:
|
||||
Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
|
||||
(B, 3*D) for concat
|
||||
"""
|
||||
if method == "last_token":
|
||||
if attention_mask is not None:
|
||||
# Find last non-padding position per sample
|
||||
last_idx = attention_mask.sum(dim=1).long() - 1 # (B,)
|
||||
return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
|
||||
return hidden_states[:, -1, :]
|
||||
|
||||
if method == "mean_pooling":
|
||||
if attention_mask is not None:
|
||||
mask = attention_mask.unsqueeze(-1).float() # (B, S, 1)
|
||||
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
|
||||
return hidden_states.mean(dim=1)
|
||||
|
||||
if method == "completion_mean":
|
||||
# Mean pool over completion tokens only (exclude prompt)
|
||||
if prompt_lengths is None:
|
||||
raise ValueError("completion_mean requires prompt_lengths")
|
||||
B, S, _ = hidden_states.shape
|
||||
positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S)
|
||||
comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S)
|
||||
if attention_mask is not None:
|
||||
comp_mask = comp_mask & attention_mask.bool()
|
||||
mask = comp_mask.unsqueeze(-1).float() # (B, S, 1)
|
||||
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
|
||||
|
||||
if method == "concat":
|
||||
B, S, D = hidden_states.shape
|
||||
if attention_mask is not None:
|
||||
valid_lens = attention_mask.sum(dim=1).long() # (B,)
|
||||
else:
|
||||
valid_lens = torch.full(
|
||||
(B,), S, device=hidden_states.device, dtype=torch.long
|
||||
)
|
||||
# Compute quartile positions relative to valid length per sample
|
||||
# First valid position index for each sample (handles right-padding)
|
||||
q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
|
||||
q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
|
||||
q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
|
||||
batch_idx = torch.arange(B, device=hidden_states.device)
|
||||
return torch.cat(
|
||||
[
|
||||
hidden_states[batch_idx, q1],
|
||||
hidden_states[batch_idx, q2],
|
||||
hidden_states[batch_idx, q3],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown embed_method: {method}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_alignment_rewards(
|
||||
gen_embedding: torch.Tensor,
|
||||
gt_embedding: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute alignment reward as cosine similarity between generated
|
||||
and ground-truth feature embeddings.
|
||||
|
||||
Args:
|
||||
gen_embedding: (B, D) generated sequence embeddings
|
||||
gt_embedding: (B, D) ground-truth sequence embeddings
|
||||
If num_generations > 1, gt_embedding should be repeated
|
||||
to match gen_embedding's batch dim.
|
||||
|
||||
Returns:
|
||||
Alignment rewards: (B,) cosine similarities in [-1, 1]
|
||||
"""
|
||||
return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_diversity_rewards(
|
||||
gen_embedding: torch.Tensor,
|
||||
num_generations: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute diversity penalty as mean pairwise dot-product similarity
|
||||
between samples from the same prompt (excluding self-similarity).
|
||||
|
||||
Args:
|
||||
gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
|
||||
num_generations: Number of generations per prompt
|
||||
|
||||
Returns:
|
||||
Diversity penalties: (B,) mean similarity to other samples from same prompt
|
||||
"""
|
||||
if num_generations <= 1:
|
||||
return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)
|
||||
|
||||
num_prompts = gen_embedding.shape[0] // num_generations
|
||||
|
||||
# Reshape to (num_prompts, num_generations, D)
|
||||
reshaped = gen_embedding.view(num_prompts, num_generations, -1)
|
||||
|
||||
# Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
|
||||
sims = torch.bmm(reshaped, reshaped.transpose(1, 2))
|
||||
|
||||
# Zero out self-similarity (diagonal)
|
||||
eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
|
||||
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
|
||||
|
||||
# Mean similarity to other samples: (num_prompts, num_generations)
|
||||
diversity = sims.sum(dim=-1) / (num_generations - 1)
|
||||
|
||||
# Flatten back to (B,)
|
||||
return diversity.view(-1)
|
||||
|
||||
|
||||
def whiten_embeddings_batched(
|
||||
phi: torch.Tensor,
|
||||
phi_gt: torch.Tensor,
|
||||
whiten_tol: float = 1e-5,
|
||||
normalize: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Whiten generated embeddings using SVD, then apply same transform to ground-truth.
|
||||
|
||||
Whitening decorrelates feature dimensions so no single direction dominates
|
||||
the feature-matching loss. Uses pseudo-inverse for rank-deficient cases.
|
||||
|
||||
Note: Singular values scale with sqrt(B), so reward magnitudes are
|
||||
batch-size dependent. This is acceptable because B = n_samples_per_prompt
|
||||
which is fixed during training (typically 2-4).
|
||||
|
||||
Args:
|
||||
phi: (B, D) generated embeddings (used to estimate covariance)
|
||||
phi_gt: (B, D) ground-truth embeddings
|
||||
whiten_tol: Tolerance for singular value cutoff
|
||||
normalize: If True, L2-normalize after whitening
|
||||
|
||||
Returns:
|
||||
Whitened (phi, phi_gt) tuple, each (B, D)
|
||||
"""
|
||||
phi_f = phi.float()
|
||||
phi_gt_f = phi_gt.float()
|
||||
|
||||
# Feature-space SVD: operate on phi_f.T (D, B) so U is (D, D)
|
||||
try:
|
||||
U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
|
||||
except torch._C._LinAlgError:
|
||||
# Fallback: add small noise
|
||||
noise = 1e-6 * phi_f.abs().mean()
|
||||
try:
|
||||
U, S, _ = torch.linalg.svd(
|
||||
(phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
|
||||
full_matrices=False,
|
||||
)
|
||||
except torch._C._LinAlgError:
|
||||
if normalize:
|
||||
return (
|
||||
F.normalize(phi, p=2, dim=-1),
|
||||
F.normalize(phi_gt, p=2, dim=-1),
|
||||
)
|
||||
return phi, phi_gt
|
||||
|
||||
U, S = U.squeeze(0), S.squeeze(0) # U: (D, min(D,B)), S: (min(D,B),)
|
||||
|
||||
# Safe inverse of singular values
|
||||
s_max = S.max()
|
||||
inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))
|
||||
|
||||
# W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
|
||||
W = (U * inv_s.unsqueeze(0)) @ U.T # (D, D)
|
||||
phi_w = (phi_f @ W).to(phi.dtype) # (B, D)
|
||||
phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype) # (B, D)
|
||||
|
||||
if normalize:
|
||||
phi_w = F.normalize(phi_w, p=2, dim=-1)
|
||||
phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)
|
||||
|
||||
return phi_w, phi_gt_w
|
||||
1152
src/axolotl/core/trainers/ebft/strided.py
Normal file
1152
src/axolotl/core/trainers/ebft/strided.py
Normal file
File diff suppressed because it is too large
Load Diff
531
src/axolotl/core/trainers/ebft/trainer.py
Normal file
531
src/axolotl/core/trainers/ebft/trainer.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.
|
||||
|
||||
Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
|
||||
the standard GRPO reward function interface.
|
||||
|
||||
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
|
||||
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
|
||||
|
||||
from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
|
||||
from axolotl.core.trainers.ebft.rewards import (
|
||||
apply_embed_method,
|
||||
extract_hidden_states,
|
||||
get_alignment_rewards,
|
||||
get_diversity_rewards,
|
||||
whiten_embeddings_batched,
|
||||
)
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlAsyncGRPOTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
|
||||
from accelerate import Accelerator
|
||||
from trl.generation.vllm_generation import VLLMGeneration
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class EBFTMixin:
|
||||
"""
|
||||
Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.
|
||||
|
||||
Provides:
|
||||
- Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
|
||||
- _feature_matching_reward() callable for GRPO reward function interface
|
||||
- _sequential_rollout() for multi-turn conversations
|
||||
"""
|
||||
|
||||
# Type stubs for attributes provided by the composed GRPOTrainer base class.
|
||||
# These are not defined here but accessed via cooperative multiple inheritance.
|
||||
if TYPE_CHECKING:
|
||||
accelerator: Accelerator
|
||||
model: PreTrainedModel
|
||||
args: AxolotlEBFTConfig
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
num_generations: int
|
||||
vllm_generation: VLLMGeneration
|
||||
_metrics: defaultdict
|
||||
|
||||
_tag_names = ["trl", "ebft", "axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | PreTrainedModel,
|
||||
args: AxolotlEBFTConfig | None = None,
|
||||
train_dataset: Dataset | IterableDataset | None = None,
|
||||
eval_dataset: Dataset
|
||||
| IterableDataset
|
||||
| dict[str, Dataset | IterableDataset]
|
||||
| None = None,
|
||||
processing_class: PreTrainedTokenizerBase | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[
|
||||
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
|
||||
] = (None, None),
|
||||
peft_config: Any | None = None,
|
||||
):
|
||||
# Pass our feature-matching reward function to GRPOTrainer
|
||||
# It will be called with (prompts, completions, **kwargs) where
|
||||
# kwargs includes all extra dataset fields like "ground_truth"
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
model=model,
|
||||
reward_funcs=[self._feature_matching_reward],
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
assert args is not None
|
||||
|
||||
# --- Feature network setup ---
|
||||
unwrapped = self.accelerator.unwrap_model(self.model)
|
||||
# Check for PEFT model — use hasattr for robustness across DDP/FSDP wrapping
|
||||
self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
|
||||
unwrapped, "disable_adapter"
|
||||
)
|
||||
|
||||
if self._share_feature_weights:
|
||||
# Share weights: use actor's base model with adapters disabled.
|
||||
# Saves a full model copy (~8 GB for 4B model).
|
||||
self.feature_network = None
|
||||
param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
|
||||
LOG.info(
|
||||
f"EBFT feature network shares actor weights (PEFT disable_adapter). "
|
||||
f"Saving ~{param_gb:.1f} GB"
|
||||
)
|
||||
else:
|
||||
LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
|
||||
self.feature_network = copy.deepcopy(unwrapped)
|
||||
for param in self.feature_network.parameters():
|
||||
param.requires_grad = False
|
||||
self.feature_network.eval()
|
||||
|
||||
# Compute layer indices from fractional depths
|
||||
# Handle VLM models where num_hidden_layers is on text_config
|
||||
config = unwrapped.config
|
||||
if hasattr(config, "text_config") and hasattr(
|
||||
config.text_config, "num_hidden_layers"
|
||||
):
|
||||
config = config.text_config
|
||||
num_layers = config.num_hidden_layers
|
||||
self.feature_layer_indices = [
|
||||
int(frac * num_layers) for frac in args.ebft_feature_layers
|
||||
]
|
||||
LOG.info(
|
||||
f"EBFT feature extraction from layers {self.feature_layer_indices} "
|
||||
f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
|
||||
)
|
||||
if args.ebft_adaptive_max_tokens:
|
||||
LOG.info(
|
||||
f"EBFT adaptive max_tokens enabled "
|
||||
f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
|
||||
)
|
||||
|
||||
_adaptive_max_lock = None # initialized lazily
|
||||
|
||||
def _generate_only(self, inputs, rank0_only=False):
|
||||
"""Override to set per-batch max_tokens based on ground-truth length.
|
||||
|
||||
Uses a lock to prevent race conditions in async mode where concurrent
|
||||
BG threads could interleave mutations of max_completion_length.
|
||||
"""
|
||||
import threading
|
||||
|
||||
args = self.args
|
||||
if (
|
||||
args.ebft_adaptive_max_tokens
|
||||
and hasattr(self, "vllm_generation")
|
||||
and inputs
|
||||
):
|
||||
gt_texts = [
|
||||
x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
|
||||
]
|
||||
if gt_texts:
|
||||
gt_token_counts = [
|
||||
len(self.processing_class.encode(gt, add_special_tokens=False))
|
||||
for gt in gt_texts
|
||||
]
|
||||
multiplier = args.ebft_gt_length_multiplier
|
||||
max_completion = self.vllm_generation.max_completion_length
|
||||
adaptive_max = max(
|
||||
min(int(c * multiplier), max_completion) for c in gt_token_counts
|
||||
)
|
||||
adaptive_max = max(adaptive_max, 64)
|
||||
|
||||
if self._adaptive_max_lock is None:
|
||||
self._adaptive_max_lock = threading.Lock()
|
||||
with self._adaptive_max_lock:
|
||||
original = self.vllm_generation.max_completion_length
|
||||
self.vllm_generation.max_completion_length = adaptive_max
|
||||
try:
|
||||
return super()._generate_only(inputs, rank0_only)
|
||||
finally:
|
||||
self.vllm_generation.max_completion_length = original
|
||||
|
||||
return super()._generate_only(inputs, rank0_only)
|
||||
|
||||
@torch.no_grad()
|
||||
def _feature_matching_reward(
|
||||
self,
|
||||
prompts: list,
|
||||
completions: list,
|
||||
ground_truth: list[str] | None = None,
|
||||
remaining_turns: list | None = None,
|
||||
**kwargs,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Compute feature-matching rewards for generated completions.
|
||||
|
||||
This is called by GRPOTrainer's _generate_and_score_completions()
|
||||
as a standard reward function. The `ground_truth` field comes from
|
||||
the dataset via reward_kwargs.
|
||||
|
||||
For multi-turn conversations, `remaining_turns` contains the subsequent
|
||||
user/assistant turn pairs. When present, we do sequential rollouts:
|
||||
generate each assistant turn conditioned on history + previous generations,
|
||||
then compute feature-matching rewards on the full generated conversation.
|
||||
|
||||
Args:
|
||||
prompts: List of prompt strings/messages
|
||||
completions: List of generated completion strings
|
||||
ground_truth: List of reference completion strings (from dataset)
|
||||
remaining_turns: List of remaining conversation turns after the
|
||||
first assistant turn (for multi-turn rollouts)
|
||||
|
||||
Returns:
|
||||
List of scalar rewards, one per completion
|
||||
"""
|
||||
if ground_truth is None:
|
||||
LOG.warning("No ground_truth field in dataset — using zero rewards")
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
device = self.accelerator.device
|
||||
args = self.args
|
||||
num_gens = self.num_generations
|
||||
|
||||
# --- Multi-turn sequential rollout ---
|
||||
# If remaining_turns is provided, generate subsequent assistant turns
|
||||
# by calling vLLM for each turn, building up the full conversation.
|
||||
if remaining_turns is not None and hasattr(self, "vllm_generation"):
|
||||
completions = self._sequential_rollout(
|
||||
prompts, completions, remaining_turns, num_gens
|
||||
)
|
||||
|
||||
# --- Tokenize generated sequences: prompt + completion ---
|
||||
gen_texts = []
|
||||
gen_prompt_texts = []
|
||||
for p, c in zip(prompts, completions, strict=True):
|
||||
if isinstance(p, list):
|
||||
prompt_text = self.processing_class.apply_chat_template(
|
||||
p, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
prompt_text = p
|
||||
if isinstance(c, list):
|
||||
comp_text = c[0].get("content", "") if c else ""
|
||||
else:
|
||||
comp_text = c
|
||||
gen_texts.append(prompt_text + comp_text)
|
||||
gen_prompt_texts.append(prompt_text)
|
||||
|
||||
gen_encoded = self.processing_class(
|
||||
text=gen_texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=getattr(self.args, "max_length", None)
|
||||
or getattr(self.args, "max_seq_length", None)
|
||||
or 2048,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
gen_ids = gen_encoded["input_ids"].to(device)
|
||||
gen_mask = gen_encoded["attention_mask"].to(device)
|
||||
|
||||
# Compute prompt lengths for completion_mean pooling
|
||||
gen_prompt_lengths = torch.tensor(
|
||||
[
|
||||
len(self.processing_class.encode(pt, add_special_tokens=False))
|
||||
for pt in gen_prompt_texts
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
# --- Tokenize ground-truth sequences: prompt + ground_truth ---
|
||||
# For multi-turn (remaining_turns present), render the full GT conversation
|
||||
# through the chat template to preserve role markers between turns.
|
||||
gt_texts = []
|
||||
gt_prompt_texts = []
|
||||
for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
|
||||
if i % num_gens != 0:
|
||||
continue # Only need one GT per prompt group
|
||||
if isinstance(p, list):
|
||||
prompt_text = self.processing_class.apply_chat_template(
|
||||
p, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
# Multi-turn: build full GT conversation with remaining turns
|
||||
if remaining_turns is not None:
|
||||
prompt_idx = i // num_gens
|
||||
turns = (
|
||||
remaining_turns[prompt_idx]
|
||||
if prompt_idx < len(remaining_turns)
|
||||
else []
|
||||
)
|
||||
if turns:
|
||||
gt_conv = list(p) + [{"role": "assistant", "content": gt}]
|
||||
gt_conv.extend(turns)
|
||||
full_gt_text = self.processing_class.apply_chat_template(
|
||||
gt_conv, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
gt_texts.append(full_gt_text)
|
||||
gt_prompt_texts.append(prompt_text)
|
||||
continue
|
||||
else:
|
||||
prompt_text = p
|
||||
gt_texts.append(prompt_text + gt)
|
||||
gt_prompt_texts.append(prompt_text)
|
||||
|
||||
gt_encoded = self.processing_class(
|
||||
text=gt_texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=getattr(self.args, "max_length", None)
|
||||
or getattr(self.args, "max_seq_length", None)
|
||||
or 2048,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
gt_ids = gt_encoded["input_ids"].to(device)
|
||||
gt_mask = gt_encoded["attention_mask"].to(device)
|
||||
|
||||
gt_prompt_lengths = torch.tensor(
|
||||
[
|
||||
len(self.processing_class.encode(pt, add_special_tokens=False))
|
||||
for pt in gt_prompt_texts
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
# --- Extract features from frozen feature network ---
|
||||
# INVARIANT: disable_adapter() yields the unmodified base weights because
|
||||
# _sync_peft_weights_no_merge and _sync_lora_adapter never call
|
||||
# merge_adapter() — they compute merged weights as new tensors or save
|
||||
# the adapter to filesystem. Base weights are never modified in-place.
|
||||
if self._share_feature_weights:
|
||||
unwrapped = self.accelerator.unwrap_model(self.model)
|
||||
feature_ctx = unwrapped.disable_adapter()
|
||||
else:
|
||||
unwrapped = self.feature_network
|
||||
feature_ctx = contextlib.nullcontext()
|
||||
|
||||
with feature_ctx:
|
||||
was_training = unwrapped.training
|
||||
unwrapped.eval()
|
||||
gen_hidden = extract_hidden_states(
|
||||
unwrapped, gen_ids, gen_mask, self.feature_layer_indices
|
||||
)
|
||||
gt_hidden = extract_hidden_states(
|
||||
unwrapped, gt_ids, gt_mask, self.feature_layer_indices
|
||||
)
|
||||
if was_training:
|
||||
unwrapped.train()
|
||||
|
||||
# --- Pool to sequence-level embeddings ---
|
||||
gen_emb = apply_embed_method(
|
||||
gen_hidden,
|
||||
args.ebft_embed_method,
|
||||
gen_mask,
|
||||
prompt_lengths=gen_prompt_lengths,
|
||||
)
|
||||
gt_emb = apply_embed_method(
|
||||
gt_hidden,
|
||||
args.ebft_embed_method,
|
||||
gt_mask,
|
||||
prompt_lengths=gt_prompt_lengths,
|
||||
)
|
||||
|
||||
# --- Optional whitening ---
|
||||
batch_size = gen_emb.shape[0]
|
||||
if args.ebft_use_whitening and batch_size > 1:
|
||||
num_prompts = batch_size // num_gens
|
||||
gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
|
||||
whitened_gen_list = []
|
||||
whitened_gt_list = []
|
||||
for i in range(num_prompts):
|
||||
w_gen, w_gt = whiten_embeddings_batched(
|
||||
gen_reshaped[i], gt_emb[i : i + 1]
|
||||
)
|
||||
whitened_gen_list.append(w_gen)
|
||||
whitened_gt_list.append(w_gt)
|
||||
gen_emb = torch.cat(whitened_gen_list, dim=0)
|
||||
gt_emb = torch.cat(whitened_gt_list, dim=0)
|
||||
else:
|
||||
gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
|
||||
gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)
|
||||
|
||||
# Repeat gt_emb: each GT repeated num_generations times
|
||||
gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)
|
||||
|
||||
# --- Compute rewards ---
|
||||
alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
|
||||
diversity = get_diversity_rewards(gen_emb, num_gens)
|
||||
|
||||
# Scale by 2 per paper equation (7):
|
||||
# r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
|
||||
alignment = alignment * 2
|
||||
diversity = diversity * 2
|
||||
|
||||
rewards = (
|
||||
args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
|
||||
)
|
||||
|
||||
# Compute CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
|
||||
gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
|
||||
mean_gen = gen_reshaped.mean(dim=1) # (num_prompts, D)
|
||||
cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()
|
||||
|
||||
# Log feature-matching metrics to console and wandb
|
||||
_align = alignment.mean().item()
|
||||
_divers = diversity.mean().item()
|
||||
_reward = rewards.mean().item()
|
||||
_cfm = cfm_loss.item()
|
||||
|
||||
LOG.info(
|
||||
f"ebft reward | "
|
||||
f"align {_align:+.3f} ^ | "
|
||||
f"divers {_divers:+.3f} v | "
|
||||
f"cfm {_cfm:.3f} v | "
|
||||
f"reward {_reward:+.3f} ^"
|
||||
)
|
||||
|
||||
# Log to wandb via trainer's _metrics (picked up by GRPO's logging)
|
||||
mode = "train" if self.model.training else "eval"
|
||||
if hasattr(self, "_metrics"):
|
||||
self._metrics[mode]["ebft/alignment"].append(_align)
|
||||
self._metrics[mode]["ebft/diversity"].append(_divers)
|
||||
self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
|
||||
self._metrics[mode]["ebft/reward"].append(_reward)
|
||||
|
||||
return rewards.cpu().tolist()
|
||||
|
||||
@torch.no_grad()
|
||||
def _sequential_rollout(
|
||||
self,
|
||||
prompts: list,
|
||||
first_completions: list,
|
||||
remaining_turns: list,
|
||||
num_gens: int,
|
||||
) -> list:
|
||||
"""
|
||||
Extend single-turn completions into multi-turn conversations.
|
||||
|
||||
For each prompt group, takes the first generated assistant turn and
|
||||
sequentially generates subsequent assistant turns by calling vLLM,
|
||||
building up a full multi-turn conversation.
|
||||
|
||||
Args:
|
||||
prompts: List of prompt message lists (repeated num_gens times)
|
||||
first_completions: List of generated first-turn completions
|
||||
remaining_turns: List of remaining turn pairs after first assistant turn.
|
||||
Each element is a list of dicts: [{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "...GT..."}]
|
||||
num_gens: Number of generations per prompt
|
||||
|
||||
Returns:
|
||||
Extended completions incorporating all generated turns
|
||||
"""
|
||||
vllm_client = self.vllm_generation.vllm_client
|
||||
max_tokens = getattr(self.args, "max_completion_length", 256)
|
||||
temperature = getattr(self.args, "temperature", 0.7)
|
||||
gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}
|
||||
|
||||
extended_completions = []
|
||||
|
||||
for idx in range(len(prompts)):
|
||||
prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
|
||||
first_comp = first_completions[idx]
|
||||
|
||||
# Extract first completion text
|
||||
if isinstance(first_comp, list):
|
||||
first_text = first_comp[0].get("content", "") if first_comp else ""
|
||||
else:
|
||||
first_text = first_comp
|
||||
|
||||
# Get remaining turns for this prompt (same for all num_gens copies)
|
||||
prompt_idx = idx // num_gens
|
||||
turns = (
|
||||
remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
|
||||
)
|
||||
|
||||
if not turns:
|
||||
extended_completions.append(first_text)
|
||||
continue
|
||||
|
||||
# Build conversation with generated first turn
|
||||
conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
|
||||
|
||||
# Generate subsequent turns
|
||||
for turn in turns:
|
||||
if turn["role"] == "user":
|
||||
conv.append(turn)
|
||||
elif turn["role"] == "assistant":
|
||||
try:
|
||||
result = vllm_client.chat(
|
||||
messages=[conv],
|
||||
n=1,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
generation_kwargs=gen_kwargs,
|
||||
)
|
||||
gen_ids = result.get("completion_ids", [[]])[0]
|
||||
gen_text = self.processing_class.decode(
|
||||
gen_ids, skip_special_tokens=True
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning(f"Multi-turn rollout generation failed: {e}")
|
||||
gen_text = ""
|
||||
|
||||
conv.append({"role": "assistant", "content": gen_text})
|
||||
|
||||
# Render full conversation through chat template, then extract
|
||||
# everything after the original prompt as the "completion" text.
|
||||
# This preserves role markers and formatting between turns.
|
||||
full_rendered = self.processing_class.apply_chat_template(
|
||||
conv, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
prompt_rendered = self.processing_class.apply_chat_template(
|
||||
prompt_msgs, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
completion_text = full_rendered[len(prompt_rendered) :]
|
||||
extended_completions.append(completion_text)
|
||||
|
||||
return extended_completions
|
||||
|
||||
|
||||
class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
|
||||
"""EBFT trainer using synchronous GRPO (standard vLLM generation)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
|
||||
"""EBFT trainer using async GRPO (prefetches next batch during training)."""
|
||||
|
||||
pass
|
||||
@@ -628,13 +628,21 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
|
||||
# The communicator is not needed because weight sync happens via filesystem + HTTP,
|
||||
# and it fails when vLLM and a trainer rank share the same CUDA device.
|
||||
# Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
|
||||
# merged weight sync. NCCL is only needed for the standard update_named_param
|
||||
# path which broadcasts tensors through the communicator.
|
||||
training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
|
||||
if training_args is not None and getattr(
|
||||
training_args, "vllm_lora_sync", False
|
||||
):
|
||||
_skip_nccl = False
|
||||
if training_args is not None:
|
||||
if getattr(training_args, "vllm_lora_sync", False):
|
||||
_skip_nccl = True # LoRA sync uses filesystem + HTTP
|
||||
elif getattr(training_args, "async_prefetch", False):
|
||||
# Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
|
||||
# init_communicator allocates device tensors on rank 0 only, which
|
||||
# causes DDP to see different param counts across ranks.
|
||||
# The communicator is initialized lazily on first weight sync instead.
|
||||
_skip_nccl = True
|
||||
if _skip_nccl:
|
||||
from trl.generation.vllm_generation import VLLMGeneration
|
||||
|
||||
_orig_init_vllm = VLLMGeneration._init_vllm
|
||||
@@ -661,7 +669,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
VLLMGeneration._init_vllm = _init_vllm_no_communicator
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
try:
|
||||
super().__init__(*args, **kwargs)
|
||||
finally:
|
||||
# Restore original _init_vllm so other trainers aren't affected
|
||||
if _skip_nccl:
|
||||
VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined]
|
||||
|
||||
# FP8 models: zero out the pad token embedding so that padding
|
||||
# positions have zero hidden states throughout the network.
|
||||
@@ -780,11 +793,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self._executor = None
|
||||
|
||||
def _submit_generation(self):
|
||||
"""Submit the next background generation job."""
|
||||
"""Submit the next background generation job.
|
||||
|
||||
With multi-process (DDP/FSDP), only rank 0 generates to avoid
|
||||
cross-rank NCCL collectives from background threads. Non-rank-0
|
||||
processes enqueue a sentinel ``None`` that is replaced by a
|
||||
broadcast in ``_prepare_inputs_legacy_async``.
|
||||
"""
|
||||
rank0_only = self.accelerator.num_processes > 1
|
||||
if rank0_only and not self.accelerator.is_main_process:
|
||||
# Non-rank-0: nothing to generate; enqueue a resolved None future
|
||||
f: concurrent.futures.Future = concurrent.futures.Future()
|
||||
f.set_result(None)
|
||||
self._async_queue.put(f)
|
||||
return
|
||||
batch = next(self._prompt_iter)
|
||||
future = self._executor.submit(self._generate_only, batch)
|
||||
future = self._executor.submit(self._generate_only, batch, rank0_only)
|
||||
self._async_queue.put(future)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Broadcast rollout (legacy async, multi-process)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _broadcast_rollout(self, rollout: dict | None) -> dict:
|
||||
"""Broadcast a rank0-only rollout dict to all ranks (main thread).
|
||||
|
||||
Rank 0 has the full rollout dict from ``_generate_only``; other ranks
|
||||
have ``None``. After broadcast, tensors are moved to each rank's
|
||||
local device.
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
|
||||
obj_list = [rollout if self.accelerator.is_main_process else None]
|
||||
dist.broadcast_object_list(obj_list, src=0)
|
||||
rollout = obj_list[0]
|
||||
assert rollout is not None, "broadcast_object_list failed to deliver rollout"
|
||||
|
||||
# Move tensors to local device (broadcast deserializes to CPU)
|
||||
device = self.accelerator.device
|
||||
for key, val in rollout.items():
|
||||
if isinstance(val, torch.Tensor) and val.device != device:
|
||||
rollout[key] = val.to(device)
|
||||
|
||||
return rollout
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Weight sync
|
||||
# ------------------------------------------------------------------
|
||||
@@ -796,14 +848,18 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
for Float8), and also safe for concurrent use since it never modifies base
|
||||
weights in-place.
|
||||
"""
|
||||
model = self.vllm_generation.model
|
||||
accelerator = self.vllm_generation.accelerator
|
||||
vllm_client = self.vllm_generation.vllm_client
|
||||
fix_name = self.vllm_generation._fix_param_name_to_vllm
|
||||
|
||||
if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
|
||||
return
|
||||
|
||||
# In multi-GPU async mode, we skip NCCL communicator init to avoid
|
||||
# DDP param count mismatch and NCCL device conflicts. Weight sync
|
||||
# uses the HTTP-only fallback in batch_update_named_params instead.
|
||||
|
||||
model = self.vllm_generation.model
|
||||
vllm_client = self.vllm_generation.vllm_client
|
||||
fix_name = self.vllm_generation._fix_param_name_to_vllm
|
||||
|
||||
# Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
|
||||
lora_info = {}
|
||||
for mod_name, module in model.base_model.model.named_modules():
|
||||
@@ -826,10 +882,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
weight_name = pname.replace(".weight_scale_inv", ".weight")
|
||||
scale_inv_lookup[weight_name] = pparam.data
|
||||
|
||||
# Iterate all parameters, computing merged weights for LoRA layers.
|
||||
# Skip LoRA-specific params and FP8 scale params (scales will be
|
||||
# recomputed by vLLM when it receives the merged bf16 weight).
|
||||
# Only sync parameters that have LoRA modifications — skip unchanged
|
||||
# base weights to avoid OOM on the vLLM GPU from allocating the entire
|
||||
# model's worth of NCCL receive buffers.
|
||||
params_to_sync = []
|
||||
compute_dtype = torch.bfloat16
|
||||
for name, param in model.named_parameters():
|
||||
vllm_name = name.removeprefix("base_model.model.").replace(
|
||||
".base_layer", ""
|
||||
@@ -838,52 +895,58 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
continue
|
||||
if "original_module" in vllm_name:
|
||||
continue
|
||||
# Skip FP8 quantization scale parameters - they are recomputed
|
||||
# on the vLLM side when we update the weight itself
|
||||
if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
|
||||
continue
|
||||
if not vllm_name.endswith(".weight"):
|
||||
continue
|
||||
# fix_name strips modules_to_save.default. prefix
|
||||
raw_mod_path = vllm_name[: -len(".weight")]
|
||||
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
|
||||
mod_path = vllm_name[: -len(".weight")]
|
||||
|
||||
# Sync weights that have LoRA adapters OR are modules_to_save
|
||||
is_lora = mod_path in lora_info
|
||||
is_modules_to_save = raw_mod_path != mod_path # fix_name stripped a prefix
|
||||
if not is_lora and not is_modules_to_save:
|
||||
continue
|
||||
|
||||
data = param.data
|
||||
compute_dtype = torch.bfloat16
|
||||
|
||||
if vllm_name.endswith(".weight"):
|
||||
# Dequantize FP8 weights before merging
|
||||
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
|
||||
scale_inv = scale_inv_lookup[name]
|
||||
# Block dequantization: weight * scale_inv (with broadcasting)
|
||||
fp8_bf16 = data.to(compute_dtype)
|
||||
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
|
||||
# Block-quantized: scale_inv shape (rows/block, cols/block)
|
||||
sr, sc = scale_inv.shape
|
||||
br = fp8_bf16.shape[0] // sr # block height
|
||||
bc = fp8_bf16.shape[1] // sc # block width
|
||||
# Reshape → multiply by block scale → reshape back
|
||||
data = (
|
||||
fp8_bf16.reshape(sr, br, sc, bc)
|
||||
* scale_inv[:, None, :, None].to(compute_dtype)
|
||||
).reshape(fp8_bf16.shape)
|
||||
elif scale_inv.dim() <= 1:
|
||||
# Per-tensor or per-channel scale
|
||||
data = fp8_bf16 * scale_inv.to(compute_dtype)
|
||||
else:
|
||||
data = fp8_bf16
|
||||
elif data.dtype == torch.float8_e4m3fn:
|
||||
# FP8 but no scale found - just cast (lossy)
|
||||
data = data.to(compute_dtype)
|
||||
# Dequantize FP8 weights before merging
|
||||
if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
|
||||
scale_inv = scale_inv_lookup[name]
|
||||
fp8_bf16 = data.to(compute_dtype)
|
||||
if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
|
||||
sr, sc = scale_inv.shape
|
||||
br = fp8_bf16.shape[0] // sr
|
||||
bc = fp8_bf16.shape[1] // sc
|
||||
data = (
|
||||
fp8_bf16.reshape(sr, br, sc, bc)
|
||||
* scale_inv[:, None, :, None].to(compute_dtype)
|
||||
).reshape(fp8_bf16.shape)
|
||||
elif scale_inv.dim() <= 1:
|
||||
data = fp8_bf16 * scale_inv.to(compute_dtype)
|
||||
else:
|
||||
data = fp8_bf16
|
||||
elif data.dtype == torch.float8_e4m3fn:
|
||||
data = data.to(compute_dtype)
|
||||
|
||||
mod_path = vllm_name[: -len(".weight")]
|
||||
if mod_path in lora_info:
|
||||
A, B, s = lora_info[mod_path]
|
||||
merged = data.to(compute_dtype) + s * (
|
||||
B.to(compute_dtype) @ A.to(compute_dtype)
|
||||
)
|
||||
data = merged
|
||||
if is_lora:
|
||||
A, B, s = lora_info[mod_path]
|
||||
merged = data.to(compute_dtype) + s * (
|
||||
B.to(compute_dtype) @ A.to(compute_dtype)
|
||||
)
|
||||
params_to_sync.append((vllm_name, merged))
|
||||
else:
|
||||
# modules_to_save: send raw weight (no LoRA merge needed)
|
||||
params_to_sync.append((vllm_name, data.to(compute_dtype)))
|
||||
|
||||
params_to_sync.append((vllm_name, data))
|
||||
|
||||
# Batch sync all params in one HTTP+NCCL call (vs individual calls)
|
||||
# Batch sync only LoRA-modified params via HTTP+NCCL
|
||||
if params_to_sync:
|
||||
sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
|
||||
logger.info(
|
||||
f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
|
||||
)
|
||||
vllm_client.batch_update_named_params(params_to_sync)
|
||||
|
||||
# Reset prefix cache after weight update
|
||||
@@ -950,6 +1013,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
vllm_client = self.vllm_generation.vllm_client
|
||||
url = f"{vllm_client.base_url}/set_lora_adapter/"
|
||||
sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
|
||||
response = requests.post(
|
||||
url,
|
||||
json={
|
||||
@@ -957,7 +1021,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
"lora_int_id": self._lora_sync_version,
|
||||
"lora_path": adapter_path,
|
||||
},
|
||||
timeout=30,
|
||||
timeout=sync_timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
@@ -1008,11 +1072,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
step = self.state.global_step
|
||||
interval = self.args.vllm_sync_interval
|
||||
if step != self._last_synced_step and step % interval == 0:
|
||||
if step == 0:
|
||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||
self._last_synced_step = step
|
||||
return
|
||||
if getattr(self.args, "vllm_lora_sync", False):
|
||||
if step == 0:
|
||||
logger.info("Skipping LoRA sync at step 0 (no training yet)")
|
||||
self._last_synced_step = step
|
||||
return
|
||||
# Native LoRA sync: save adapter to filesystem, vLLM loads it directly
|
||||
self._sync_lora_adapter()
|
||||
else:
|
||||
@@ -1088,7 +1152,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
# Background-thread generation (no scoring)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_single_turn(self, prompts, **kwargs):
|
||||
def _generate_single_turn(self, prompts, *args, **kwargs):
|
||||
"""Override to prevent weight sync from background thread and to use
|
||||
no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
|
||||
is_bg = threading.current_thread() is not threading.main_thread()
|
||||
@@ -1121,7 +1185,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self._patched_sync_weights = True
|
||||
|
||||
try:
|
||||
return super()._generate_single_turn(prompts, **kwargs)
|
||||
return super()._generate_single_turn(prompts, *args, **kwargs)
|
||||
finally:
|
||||
if saved_step is not None:
|
||||
self._last_loaded_step = saved_step
|
||||
@@ -1165,9 +1229,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
output = vg.vllm_client.chat(
|
||||
messages=unique_prompts,
|
||||
**sampling_params,
|
||||
chat_template_kwargs=vg.chat_template_kwargs,
|
||||
tools=vg.tools,
|
||||
chat_template=vg.chat_template,
|
||||
chat_template_kwargs=self.chat_template_kwargs,
|
||||
tools=self.tools,
|
||||
chat_template=getattr(self, "chat_template", None),
|
||||
)
|
||||
else:
|
||||
output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
|
||||
@@ -1584,10 +1648,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
logps_diff = per_token_logps_diff
|
||||
|
||||
is_ratio = torch.exp(logps_diff)
|
||||
is_floor = 1.0 / is_cap # symmetric floor (e.g., cap=3.0 -> floor=0.333)
|
||||
if is_mode in ("sequence_truncate", "token_truncate"):
|
||||
is_ratio = torch.clamp(is_ratio, max=is_cap)
|
||||
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
|
||||
elif is_mode in ("sequence_mask", "token_mask"):
|
||||
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
||||
is_ratio = is_ratio.clamp(min=is_floor)
|
||||
data["importance_sampling_ratio"] = is_ratio
|
||||
|
||||
# --- Collect rewards (launched before logprobs, should be done) ---
|
||||
@@ -1906,10 +1972,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
seq_is = is_mode in ("sequence_mask", "sequence_truncate")
|
||||
logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
|
||||
is_ratio = torch.exp(logps_diff)
|
||||
# Symmetric floor clamp (matches non-streaming path at line ~1651)
|
||||
is_floor = 1.0 / is_cap
|
||||
if is_mode in ("sequence_truncate", "token_truncate"):
|
||||
is_ratio = torch.clamp(is_ratio, max=is_cap)
|
||||
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
|
||||
elif is_mode in ("sequence_mask", "token_mask"):
|
||||
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
||||
is_ratio = is_ratio.clamp(min=is_floor)
|
||||
if "importance_sampling_ratio" not in data:
|
||||
total = len(data["prompt_ids"])
|
||||
shape = (total, 1) if seq_is else (total, is_ratio.size(1))
|
||||
@@ -2280,6 +2349,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
rollout = future.result()
|
||||
self._submit_generation()
|
||||
|
||||
# With multi-process, only rank 0 generated. Broadcast to all ranks.
|
||||
if self.accelerator.num_processes > 1:
|
||||
rollout = self._broadcast_rollout(rollout)
|
||||
|
||||
if self.args.streaming_partial_batch:
|
||||
micro_batches = self._score_streaming(rollout)
|
||||
else:
|
||||
|
||||
@@ -145,10 +145,10 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.trainer.axolotl_cfg.use_wandb:
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
if wandb.run is not None: # type: ignore[attr-defined]
|
||||
wandb.log( # type: ignore[attr-defined]
|
||||
{
|
||||
"generated_samples": wandb.Table(
|
||||
"generated_samples": wandb.Table( # type: ignore[attr-defined]
|
||||
columns=[
|
||||
"step",
|
||||
"original",
|
||||
|
||||
@@ -20,46 +20,93 @@ LOG = logging.getLogger(__name__)
|
||||
def _batch_update_named_params(
|
||||
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
|
||||
):
|
||||
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
|
||||
from transformers import is_torch_xpu_available
|
||||
"""Batched weight sync — uses NCCL if communicator available, HTTP otherwise."""
|
||||
has_communicator = getattr(self, "communicator", None) is not None
|
||||
|
||||
if chunk_size is None:
|
||||
chunks = [params]
|
||||
else:
|
||||
chunks = []
|
||||
current_chunk: list[tuple[str, torch.Tensor]] = []
|
||||
current_elements = 0
|
||||
for name, weights in params:
|
||||
n_elem = weights.numel()
|
||||
if current_chunk and current_elements + n_elem > chunk_size:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = []
|
||||
current_elements = 0
|
||||
current_chunk.append((name, weights))
|
||||
current_elements += n_elem
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
if has_communicator:
|
||||
# Fast path: metadata via HTTP, tensors via NCCL
|
||||
from transformers import is_torch_xpu_available
|
||||
|
||||
for chunk in chunks:
|
||||
param_metadata = [
|
||||
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
|
||||
for name, weights in chunk
|
||||
]
|
||||
url = f"{self.base_url}/batch_update_named_params/"
|
||||
response = self.session.post(url, json={"params": param_metadata})
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
for _name, weights in chunk:
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.broadcast(weights, root=self.rank)
|
||||
else:
|
||||
self.communicator.broadcast(weights, src=self.rank)
|
||||
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.barrier()
|
||||
if chunk_size is None:
|
||||
chunks = [params]
|
||||
else:
|
||||
self.communicator.group.barrier()
|
||||
chunks = []
|
||||
current_chunk: list[tuple[str, torch.Tensor]] = []
|
||||
current_elements = 0
|
||||
for name, weights in params:
|
||||
n_elem = weights.numel()
|
||||
if current_chunk and current_elements + n_elem > chunk_size:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = []
|
||||
current_elements = 0
|
||||
current_chunk.append((name, weights))
|
||||
current_elements += n_elem
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
for chunk in chunks:
|
||||
param_metadata = [
|
||||
{
|
||||
"name": name,
|
||||
"dtype": str(weights.dtype),
|
||||
"shape": list(weights.shape),
|
||||
}
|
||||
for name, weights in chunk
|
||||
]
|
||||
url = f"{self.base_url}/batch_update_named_params/"
|
||||
response = self.session.post(
|
||||
url, json={"params": param_metadata}, timeout=120
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed: {response.status_code}, {response.text}"
|
||||
)
|
||||
|
||||
for _name, weights in chunk:
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.broadcast(weights, root=self.rank)
|
||||
else:
|
||||
self.communicator.broadcast(weights, src=self.rank)
|
||||
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.barrier()
|
||||
else:
|
||||
self.communicator.group.barrier()
|
||||
else:
|
||||
# HTTP-only path: encode tensor data in request body (no NCCL needed).
|
||||
# Batch by byte size to avoid huge HTTP payloads.
|
||||
MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024 # 10 MB
|
||||
HTTP_TIMEOUT = 120 # seconds per request
|
||||
|
||||
payload: list[dict] = []
|
||||
payload_bytes = 0
|
||||
url = f"{self.base_url}/http_update_weights/"
|
||||
|
||||
def _flush(p: list[dict]) -> None:
|
||||
if not p:
|
||||
return
|
||||
response = self.session.post(url, json={"params": p}, timeout=HTTP_TIMEOUT)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed: {response.status_code}, {response.text}"
|
||||
)
|
||||
|
||||
from axolotl.utils.weight_serde import encode_for_http
|
||||
|
||||
for name, weights in params:
|
||||
entry = encode_for_http(name, weights)
|
||||
entry_bytes = weights.nelement() * weights.element_size()
|
||||
|
||||
# Flush current batch if adding this entry would exceed limit
|
||||
if payload and payload_bytes + entry_bytes > MAX_BYTES_PER_REQUEST:
|
||||
_flush(payload)
|
||||
payload = []
|
||||
payload_bytes = 0
|
||||
|
||||
payload.append(entry)
|
||||
payload_bytes += entry_bytes
|
||||
|
||||
_flush(payload) # send remaining
|
||||
|
||||
|
||||
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
|
||||
|
||||
9
src/axolotl/prompt_strategies/ebft/__init__.py
Normal file
9
src/axolotl/prompt_strategies/ebft/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
module for EBFT style dataset transform strategies
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
load = partial(load_base, module_base="axolotl.prompt_strategies.ebft")
|
||||
129
src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
Normal file
129
src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Dataset transform for multi-turn chat data with structured EBFT (vLLM mode).
|
||||
|
||||
Three variants:
|
||||
|
||||
1. `transform` — Uses the FIRST assistant turn as the generation target.
|
||||
Passes remaining turns as `remaining_turns` for sequential rollout.
|
||||
The trainer generates turn 1 via GRPO/vLLM, then sequentially generates
|
||||
subsequent assistant turns, comparing the full conversation to GT.
|
||||
|
||||
2. `transform_last_turn` — Uses the LAST assistant turn as the target.
|
||||
Simplest approach: the full conversation history is the prompt.
|
||||
|
||||
3. `transform_all_turns` — Explodes each conversation into N examples
|
||||
(one per assistant turn). Each turn is an independent training example.
|
||||
Use with batched=True.
|
||||
|
||||
Supports OpenAI chat format:
|
||||
{"messages": [{"role": ..., "content": ...}, ...]}
|
||||
"""
|
||||
|
||||
|
||||
def transform(cfg, **kwargs):
|
||||
"""Multi-turn with sequential rollout.
|
||||
|
||||
Returns the first assistant turn as ground_truth, plus remaining_turns
|
||||
for the trainer to do sequential rollout generation.
|
||||
"""
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
if not messages:
|
||||
return {"prompt": [], "ground_truth": ""}
|
||||
|
||||
# Split at first assistant turn
|
||||
prompt_msgs = []
|
||||
first_gt = None
|
||||
remaining = []
|
||||
|
||||
found_first = False
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant" and not found_first:
|
||||
first_gt = msg["content"]
|
||||
found_first = True
|
||||
elif found_first:
|
||||
remaining.append(msg)
|
||||
else:
|
||||
prompt_msgs.append(msg)
|
||||
|
||||
if first_gt is None:
|
||||
return {"prompt": prompt_msgs, "ground_truth": ""}
|
||||
|
||||
# Store only the first assistant turn as ground_truth. The full multi-turn
|
||||
# GT is reconstructed in the reward function via chat template rendering
|
||||
# (using remaining_turns), which preserves role markers between turns.
|
||||
return {
|
||||
"prompt": prompt_msgs,
|
||||
"ground_truth": first_gt,
|
||||
"remaining_turns": remaining,
|
||||
}
|
||||
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
|
||||
|
||||
def transform_last_turn(cfg, **kwargs):
|
||||
"""Single-turn: use the last assistant turn as the generation target."""
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
if not messages:
|
||||
return {"prompt": [], "ground_truth": ""}
|
||||
|
||||
# Find all assistant turns
|
||||
history = []
|
||||
last_prompt = []
|
||||
last_gt = ""
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant":
|
||||
last_prompt = list(history)
|
||||
last_gt = msg["content"]
|
||||
history.append(msg)
|
||||
|
||||
return {
|
||||
"prompt": last_prompt,
|
||||
"ground_truth": last_gt,
|
||||
}
|
||||
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
|
||||
|
||||
def transform_all_turns(cfg, **kwargs):
|
||||
"""Explode: one example per assistant turn.
|
||||
|
||||
Use with datasets.map(batched=True) to produce N examples from
|
||||
each N-turn conversation.
|
||||
|
||||
Usage in YAML:
|
||||
type: ebft_chat_multiturn.transform_all_turns
|
||||
"""
|
||||
|
||||
def transform_fn(examples, tokenizer=None):
|
||||
all_prompts = []
|
||||
all_ground_truths = []
|
||||
|
||||
messages_list = examples.get("messages", examples.get("conversations", []))
|
||||
|
||||
for messages in messages_list:
|
||||
history = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant":
|
||||
all_prompts.append(list(history))
|
||||
all_ground_truths.append(msg["content"])
|
||||
history.append(msg)
|
||||
|
||||
return {
|
||||
"prompt": all_prompts,
|
||||
"ground_truth": all_ground_truths,
|
||||
}
|
||||
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
"batched": True,
|
||||
}
|
||||
20
src/axolotl/prompt_strategies/ebft/ebft_opencode.py
Normal file
20
src/axolotl/prompt_strategies/ebft/ebft_opencode.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Dataset transform for nvidia/OpenCodeInstruct with EBFT structured mode.
|
||||
|
||||
Maps the dataset's `input` (prompt) and `output` (code solution) fields
|
||||
to the format expected by the EBFT trainer (prompt + ground_truth).
|
||||
"""
|
||||
|
||||
|
||||
def transform(cfg, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "user", "content": example["input"]},
|
||||
],
|
||||
"ground_truth": example["output"],
|
||||
}
|
||||
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
319
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
Normal file
319
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Dataset transform for reasoning/thinking datasets with EBFT.
|
||||
|
||||
Handles datasets where assistant responses contain <think>...</think> reasoning
|
||||
traces (e.g., TeichAI/Claude-Opus-4.6-Reasoning, Qwen3.5 thinking mode outputs).
|
||||
|
||||
Two variants:
|
||||
|
||||
1. `transform` — For structured EBFT (vLLM mode):
|
||||
Returns prompt + ground_truth with thinking tags preserved.
|
||||
Feature matching compares full responses (thinking + answer).
|
||||
|
||||
2. `transform_answer_only` — For structured EBFT (vLLM mode):
|
||||
Strips <think>...</think> from ground_truth, so feature matching
|
||||
only scores the final answer portion. Use when reasoning chains
|
||||
can vary but the answer should match.
|
||||
|
||||
3. `transform_strided` — For strided EBFT:
|
||||
Tokenizes the full conversation with thinking traces.
|
||||
Optionally masks thinking tokens from CE loss (labels=-100 for think spans)
|
||||
while still placing anchors in thinking regions for feature matching.
|
||||
|
||||
All variants work with OpenAI chat format:
|
||||
{"messages": [{"role": "...", "content": "<think>...</think>Answer"}]}
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def _strip_thinking(text: str) -> str:
|
||||
"""Remove <think>...</think> blocks from text."""
|
||||
return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
def _extract_thinking(text: str) -> tuple[str, str]:
|
||||
"""Split text into (thinking, answer) parts."""
|
||||
match = re.search(r"<think>(.*?)</think>\s*(.*)", text, flags=re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip(), match.group(2).strip()
|
||||
return "", text.strip()
|
||||
|
||||
|
||||
def transform(cfg, **kwargs):
|
||||
"""Full response including thinking traces for feature matching.
|
||||
|
||||
For datasets where assistant content has <think>...</think> tags in the
|
||||
content field. The ground_truth includes the full content (thinking + answer).
|
||||
"""
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
prompt_msgs_snapshot = None
|
||||
ground_truth = ""
|
||||
for msg_idx, msg in enumerate(messages):
|
||||
if msg["role"] == "assistant":
|
||||
prompt_msgs_snapshot = list(messages[:msg_idx])
|
||||
ground_truth = msg["content"]
|
||||
|
||||
return {
|
||||
"prompt": prompt_msgs_snapshot
|
||||
if prompt_msgs_snapshot is not None
|
||||
else messages[:-1],
|
||||
"ground_truth": ground_truth,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
|
||||
|
||||
def transform_split_thinking(cfg, **kwargs):
|
||||
"""Split <think> tags into reasoning_content field for native chat template handling.
|
||||
|
||||
For datasets where thinking is embedded in the content field as <think>...</think>.
|
||||
Splits it into separate reasoning_content and content fields so the model's
|
||||
chat template can format it natively (e.g., Qwen3.5's reasoning_content support).
|
||||
|
||||
The prompt messages are passed through with reasoning_content properly split,
|
||||
so vLLM generation with enable_thinking=true produces comparable outputs.
|
||||
The ground_truth is the full assistant response (thinking + answer) for
|
||||
feature matching.
|
||||
|
||||
Also works for:
|
||||
- <reasoning>...</reasoning> tags
|
||||
- <|begin_of_thought|>...<|end_of_thought|> tags
|
||||
"""
|
||||
_THINKING_PAIRS = [
|
||||
("<think>", "</think>"),
|
||||
("<reasoning>", "</reasoning>"),
|
||||
("<|begin_of_thought|>", "<|end_of_thought|>"),
|
||||
]
|
||||
|
||||
def _split_msg_thinking(msg):
|
||||
"""Split thinking from assistant message content into reasoning_content.
|
||||
|
||||
Always includes reasoning_content key on assistant messages (empty string
|
||||
if no thinking tags found) to ensure consistent HF dataset schema across
|
||||
all examples in a batch.
|
||||
"""
|
||||
if msg["role"] != "assistant":
|
||||
return msg
|
||||
content = msg.get("content", "")
|
||||
# Already has reasoning_content — pass through
|
||||
if "reasoning_content" in msg:
|
||||
return msg
|
||||
for open_tag, close_tag in _THINKING_PAIRS:
|
||||
if open_tag in content and close_tag in content:
|
||||
start = content.find(open_tag)
|
||||
end = content.find(close_tag)
|
||||
thinking = content[start + len(open_tag) : end].strip()
|
||||
answer = content[end + len(close_tag) :].strip()
|
||||
return {
|
||||
**msg,
|
||||
"reasoning_content": thinking,
|
||||
"content": answer,
|
||||
}
|
||||
# No thinking tags — still add reasoning_content for schema consistency
|
||||
return {**msg, "reasoning_content": ""}
|
||||
|
||||
def _normalize_msg(msg):
|
||||
"""Ensure every message has {role, content, reasoning_content} for HF schema consistency."""
|
||||
return {
|
||||
"role": msg.get("role", ""),
|
||||
"content": msg.get("content", ""),
|
||||
"reasoning_content": msg.get("reasoning_content", ""),
|
||||
}
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
# Split thinking in all assistant messages, then normalize schema
|
||||
split_messages = [_normalize_msg(_split_msg_thinking(m)) for m in messages]
|
||||
|
||||
# Build prompt (all messages except last assistant) and ground_truth
|
||||
prompt_msgs = []
|
||||
prompt_msgs_snapshot = None
|
||||
ground_truth = ""
|
||||
for msg in split_messages:
|
||||
if msg["role"] == "assistant":
|
||||
prompt_msgs_snapshot = list(prompt_msgs)
|
||||
# ground_truth is the FULL content for feature matching
|
||||
thinking = msg.get("reasoning_content", "")
|
||||
answer = msg.get("content", "")
|
||||
if thinking:
|
||||
ground_truth = f"<think>\n{thinking}\n</think>\n\n{answer}"
|
||||
else:
|
||||
ground_truth = answer
|
||||
prompt_msgs.append(msg)
|
||||
|
||||
return {
|
||||
"prompt": prompt_msgs_snapshot
|
||||
if prompt_msgs_snapshot is not None
|
||||
else split_messages[:-1],
|
||||
"ground_truth": ground_truth,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
|
||||
|
||||
def transform_answer_only(cfg, **kwargs):
|
||||
"""Strip thinking from ground_truth — match features on answer only."""
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
prompt_msgs = []
|
||||
prompt_msgs_snapshot = None
|
||||
ground_truth = ""
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant":
|
||||
prompt_msgs_snapshot = list(prompt_msgs)
|
||||
ground_truth = _strip_thinking(msg["content"])
|
||||
prompt_msgs.append(msg)
|
||||
|
||||
return {
|
||||
"prompt": prompt_msgs_snapshot
|
||||
if prompt_msgs_snapshot is not None
|
||||
else messages[:-1],
|
||||
"ground_truth": ground_truth,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
|
||||
|
||||
def transform_strided(cfg, **kwargs):
|
||||
"""For strided EBFT: tokenize with thinking, optionally mask think tokens from CE loss.
|
||||
|
||||
Config options (via cfg):
|
||||
- ebft.mask_thinking_ce: bool (default False)
|
||||
If True, set labels=-100 for tokens inside <think>...</think> blocks.
|
||||
Feature matching still uses these positions (anchors are placed everywhere
|
||||
in the completion span). Only CE auxiliary loss is affected.
|
||||
"""
|
||||
seq_len = cfg.sequence_len
|
||||
mask_thinking = False
|
||||
if cfg.ebft and hasattr(cfg.ebft, "mask_thinking_ce"):
|
||||
mask_thinking = cfg.ebft.mask_thinking_ce
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
if tokenizer is None:
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
return {"prompt": m["content"]}
|
||||
return {"prompt": str(messages)}
|
||||
|
||||
pad_id = (
|
||||
tokenizer.pad_token_id
|
||||
if tokenizer.pad_token_id is not None
|
||||
else tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# Tokenize the full conversation with the chat template
|
||||
full_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
full_enc = tokenizer(
|
||||
full_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
input_ids = full_enc["input_ids"]
|
||||
|
||||
# Build labels: -100 for non-assistant tokens
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
# Find assistant turn boundaries using incremental tokenization.
|
||||
# Only the FINAL assistant turn is marked as trainable.
|
||||
prefix_messages = []
|
||||
final_start = None
|
||||
final_end = None
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant":
|
||||
prefix_text = tokenizer.apply_chat_template(
|
||||
prefix_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prefix_ids = tokenizer(
|
||||
prefix_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)["input_ids"]
|
||||
start = len(prefix_ids)
|
||||
|
||||
prefix_messages.append(msg)
|
||||
with_turn_text = tokenizer.apply_chat_template(
|
||||
prefix_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
with_turn_ids = tokenizer(
|
||||
with_turn_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)["input_ids"]
|
||||
end = len(with_turn_ids)
|
||||
|
||||
# Record this turn's boundaries; only the last one will be used
|
||||
final_start = start
|
||||
final_end = end
|
||||
else:
|
||||
prefix_messages.append(msg)
|
||||
|
||||
# Mark only the final assistant turn as trainable
|
||||
if final_start is not None and final_end is not None:
|
||||
for i in range(final_start, min(final_end, len(labels))):
|
||||
labels[i] = input_ids[i]
|
||||
|
||||
# Optionally mask <think>...</think> tokens within this turn.
|
||||
# Find think spans by scanning for <think> and </think> token IDs
|
||||
# directly in the input_ids (robust to tokenization alignment).
|
||||
if mask_thinking:
|
||||
think_open_id = tokenizer.convert_tokens_to_ids("<think>")
|
||||
think_close_id = tokenizer.convert_tokens_to_ids("</think>")
|
||||
if think_open_id != tokenizer.unk_token_id:
|
||||
# Scan from before the assistant turn start to catch
|
||||
# <think> tags that are part of the template prefix
|
||||
scan_start = max(0, final_start - 5)
|
||||
in_think = False
|
||||
for i in range(scan_start, min(final_end, len(labels))):
|
||||
if input_ids[i] == think_open_id:
|
||||
in_think = True
|
||||
if in_think and i >= final_start:
|
||||
labels[i] = -100
|
||||
if input_ids[i] == think_close_id:
|
||||
in_think = False
|
||||
if i >= final_start:
|
||||
labels[i] = -100
|
||||
|
||||
# Derive prompt_length
|
||||
prompt_length = len(input_ids)
|
||||
for i, lbl in enumerate(labels):
|
||||
if lbl != -100:
|
||||
prompt_length = i
|
||||
break
|
||||
|
||||
# Pad
|
||||
pad_len = seq_len - len(input_ids)
|
||||
attention_mask = [1] * len(input_ids) + [0] * pad_len
|
||||
labels = labels + [-100] * pad_len
|
||||
input_ids = input_ids + [pad_id] * pad_len
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"prompt_length": prompt_length,
|
||||
}
|
||||
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
110
src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
Normal file
110
src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Dataset transform for multi-turn chat data with strided EBFT.
|
||||
|
||||
Tokenizes conversations using the model's chat template, producing input_ids
|
||||
with labels=-100 for system/user turns and real labels for assistant turns.
|
||||
The strided trainer places anchors only within assistant completion spans.
|
||||
|
||||
Works with datasets in OpenAI chat format:
|
||||
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
"""
|
||||
|
||||
|
||||
def transform(cfg, **kwargs):
|
||||
seq_len = cfg.sequence_len
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
messages = example.get("messages", example.get("conversations", []))
|
||||
|
||||
if tokenizer is None:
|
||||
# For preview: just return the first user message
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
return {"prompt": m["content"]}
|
||||
return {"prompt": str(messages)}
|
||||
|
||||
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
||||
|
||||
# Tokenize the full conversation with the chat template
|
||||
full_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
full_enc = tokenizer(
|
||||
full_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
input_ids = full_enc["input_ids"]
|
||||
|
||||
# Build labels: -100 for everything except assistant turns.
|
||||
# Strategy: tokenize incrementally to find assistant turn boundaries.
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
# Tokenize prefix up to each assistant turn to find boundaries
|
||||
prefix_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "assistant":
|
||||
# Tokenize prefix (everything before this assistant turn + generation prompt)
|
||||
prefix_text = tokenizer.apply_chat_template(
|
||||
prefix_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prefix_ids = tokenizer(
|
||||
prefix_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)["input_ids"]
|
||||
start = len(prefix_ids)
|
||||
|
||||
# Tokenize prefix + this assistant turn
|
||||
prefix_messages.append(msg)
|
||||
with_turn_text = tokenizer.apply_chat_template(
|
||||
prefix_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
with_turn_ids = tokenizer(
|
||||
with_turn_text,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)["input_ids"]
|
||||
end = len(with_turn_ids)
|
||||
|
||||
# Mark assistant tokens as trainable
|
||||
for i in range(start, min(end, len(labels))):
|
||||
labels[i] = input_ids[i]
|
||||
else:
|
||||
prefix_messages.append(msg)
|
||||
|
||||
# Derive prompt_length as the position of the first non-masked label
|
||||
prompt_length = len(input_ids) # default: all masked
|
||||
for i, lbl in enumerate(labels):
|
||||
if lbl != -100:
|
||||
prompt_length = i
|
||||
break
|
||||
|
||||
# Pad to seq_len
|
||||
pad_len = seq_len - len(input_ids)
|
||||
attention_mask = [1] * len(input_ids) + [0] * pad_len
|
||||
labels = labels + [-100] * pad_len
|
||||
input_ids = input_ids + [pad_id] * pad_len
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"prompt_length": prompt_length,
|
||||
}
|
||||
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Dataset transform for structured (prompt, completion) data with strided EBFT.
|
||||
|
||||
Tokenizes prompt and completion separately, concatenates into a single
|
||||
input_ids sequence, and marks prompt tokens with labels=-100 so the
|
||||
strided trainer knows where to place anchors (completion span only).
|
||||
|
||||
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
|
||||
"""
|
||||
|
||||
|
||||
def transform(cfg, **kwargs):
|
||||
seq_len = cfg.sequence_len
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
# Extract prompt and completion from the example
|
||||
prompt_text = example.get(
|
||||
"input", example.get("prompt", example.get("question", ""))
|
||||
)
|
||||
completion_text = example.get(
|
||||
"output", example.get("completion", example.get("answer", ""))
|
||||
)
|
||||
|
||||
if tokenizer is None:
|
||||
return {"prompt": prompt_text}
|
||||
|
||||
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
||||
|
||||
# Tokenize prompt and completion separately
|
||||
prompt_enc = tokenizer(
|
||||
prompt_text,
|
||||
truncation=False,
|
||||
add_special_tokens=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
completion_enc = tokenizer(
|
||||
completion_text,
|
||||
truncation=False,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
prompt_ids = prompt_enc["input_ids"]
|
||||
completion_ids = completion_enc["input_ids"]
|
||||
|
||||
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
|
||||
total_len = len(prompt_ids) + len(completion_ids)
|
||||
if total_len > seq_len:
|
||||
# Truncate completion first, then prompt if needed
|
||||
max_completion = seq_len - len(prompt_ids)
|
||||
if max_completion < 1:
|
||||
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
|
||||
prompt_ids = prompt_ids[: seq_len - 1]
|
||||
completion_ids = completion_ids[:1]
|
||||
else:
|
||||
completion_ids = completion_ids[:max_completion]
|
||||
|
||||
input_ids = prompt_ids + completion_ids
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
# Labels: -100 for prompt tokens, input_ids for completion tokens
|
||||
labels = [-100] * prompt_length + completion_ids
|
||||
|
||||
# Pad to seq_len
|
||||
pad_len = seq_len - len(input_ids)
|
||||
attention_mask = [1] * len(input_ids) + [0] * pad_len
|
||||
labels = labels + [-100] * pad_len
|
||||
input_ids = input_ids + [pad_id] * pad_len
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"prompt_length": prompt_length,
|
||||
}
|
||||
|
||||
# Signal to remove all original columns (filtered to existing ones at map time)
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
@@ -241,6 +241,23 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# --- Access logging middleware ---
|
||||
import time as _time
|
||||
|
||||
@app.middleware("http")
|
||||
async def access_log_middleware(request, call_next):
|
||||
t0 = _time.monotonic()
|
||||
response = await call_next(request)
|
||||
elapsed = _time.monotonic() - t0
|
||||
logger.info(
|
||||
"%s %s %d %.3fs",
|
||||
request.method,
|
||||
request.url.path,
|
||||
response.status_code,
|
||||
elapsed,
|
||||
)
|
||||
return response
|
||||
|
||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||
active_lora: dict = {"request": None}
|
||||
|
||||
@@ -300,7 +317,11 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
import vllm
|
||||
from packaging.version import Version
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
try:
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
except ImportError:
|
||||
GuidedDecodingParams = None # not available in vLLM 0.17+
|
||||
|
||||
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
|
||||
prompts: list[dict[str, Any]] = []
|
||||
@@ -362,7 +383,12 @@ def main(script_args: ScriptArguments):
|
||||
}
|
||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||
|
||||
all_outputs = [conn.recv() for conn in connections]
|
||||
# Use run_in_executor so blocking recv() doesn't freeze the event loop
|
||||
# (allows /set_lora_adapter/ and other endpoints to be served concurrently)
|
||||
loop = asyncio.get_running_loop()
|
||||
all_outputs = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, conn.recv) for conn in connections)
|
||||
)
|
||||
all_outputs = [
|
||||
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
|
||||
]
|
||||
@@ -404,7 +430,10 @@ def main(script_args: ScriptArguments):
|
||||
}
|
||||
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
|
||||
|
||||
all_outputs = [conn.recv() for conn in connections]
|
||||
loop = asyncio.get_running_loop()
|
||||
all_outputs = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, conn.recv) for conn in connections)
|
||||
)
|
||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
@@ -474,11 +503,51 @@ def main(script_args: ScriptArguments):
|
||||
)
|
||||
return {"message": f"Batch update for {len(params_list)} params"}
|
||||
|
||||
class HTTPWeightUpdateRequest(BaseModel):
|
||||
"""Weight update via HTTP (no NCCL needed)."""
|
||||
|
||||
params: list[
|
||||
dict
|
||||
] # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]
|
||||
|
||||
@app.post("/http_update_weights/")
|
||||
async def http_update_weights(request: HTTPWeightUpdateRequest):
|
||||
"""Update model weights via HTTP — no NCCL communicator required.
|
||||
|
||||
Tensor data is sent as base64-encoded raw bytes in the request body.
|
||||
Slower than NCCL for large models but works without cross-process setup.
|
||||
"""
|
||||
from axolotl.utils.weight_serde import (
|
||||
decode_from_http,
|
||||
encode_for_ipc,
|
||||
)
|
||||
|
||||
weights_to_load = [decode_from_http(p) for p in request.params]
|
||||
|
||||
# Send all weights in a single IPC call. Tensors don't survive
|
||||
# vLLM's multiproc IPC, so serialize as raw bytes + metadata.
|
||||
param_entries = [
|
||||
encode_for_ipc(name, weight) for name, weight in weights_to_load
|
||||
]
|
||||
kwargs = {
|
||||
"method": "http_load_weights_batch",
|
||||
"kwargs": {"params": param_entries},
|
||||
}
|
||||
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
|
||||
loop = asyncio.get_running_loop()
|
||||
await asyncio.gather(
|
||||
*(loop.run_in_executor(None, c.send, msg) for c in connections)
|
||||
)
|
||||
return {"message": f"HTTP weight update for {len(weights_to_load)} params"}
|
||||
|
||||
@app.post("/reset_prefix_cache/")
|
||||
async def reset_prefix_cache():
|
||||
for conn in connections:
|
||||
conn.send({"type": "call", "method": "reset_prefix_cache"})
|
||||
results = [conn.recv() for conn in connections]
|
||||
loop = asyncio.get_running_loop()
|
||||
results = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, conn.recv) for conn in connections)
|
||||
)
|
||||
return {"message": f"Reset prefix cache: {all(results)}"}
|
||||
|
||||
@app.post("/close_communicator/")
|
||||
|
||||
@@ -51,6 +51,19 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
|
||||
model = self.model_runner.model
|
||||
params_dict = dict(model.named_parameters())
|
||||
|
||||
# Handle VLM models where trainer and vLLM use different prefixes.
|
||||
# Trainer (PEFT stripped): "model.layers.X..." or "model.language_model.layers.X..."
|
||||
# vLLM (Qwen3.5): "language_model.model.layers.X..."
|
||||
if name not in params_dict:
|
||||
# Try common prefix remappings
|
||||
for src_prefix, dst_prefix in [
|
||||
("model.language_model.layers.", "language_model.model.layers."),
|
||||
("model.layers.", "language_model.model.layers."),
|
||||
]:
|
||||
if name.startswith(src_prefix):
|
||||
name = dst_prefix + name[len(src_prefix) :]
|
||||
break
|
||||
|
||||
# Check if this is a simple direct param (exists as-is)
|
||||
if name in params_dict:
|
||||
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
|
||||
@@ -106,7 +119,15 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
|
||||
return
|
||||
|
||||
# Fallback: try load_weights (may work for non-stacked params)
|
||||
logger.warning("Falling back to load_weights for param: %s", name)
|
||||
# Log the actual param names available for debugging
|
||||
sample_keys = [
|
||||
k for k in params_dict if "layers.31.mlp" in k or "layers.31.self_attn" in k
|
||||
][:3]
|
||||
logger.warning(
|
||||
"Falling back to load_weights for param: %s (sample vLLM keys: %s)",
|
||||
name,
|
||||
sample_keys,
|
||||
)
|
||||
model.load_weights(weights=[(name, weight)])
|
||||
|
||||
def update_named_param(self, name, dtype, shape):
|
||||
@@ -156,3 +177,32 @@ class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
|
||||
# Load weights using direct set (handles stacked params)
|
||||
for name, weight in weights_to_load:
|
||||
self._direct_set_weight(name, weight)
|
||||
|
||||
def http_load_weights(self, weights: list[tuple[str, torch.Tensor]]):
|
||||
"""Load weights received via HTTP (no NCCL needed)."""
|
||||
for name, weight in weights:
|
||||
self._direct_set_weight(name, weight.to(self.device))
|
||||
|
||||
def http_load_weight(self, **kwargs):
|
||||
"""Load a single weight received via HTTP (no NCCL needed).
|
||||
|
||||
Reconstructs the tensor from raw bytes since tensors don't survive
|
||||
vLLM's multiproc IPC serialization. Uses vLLM's ``load_weights``
|
||||
which handles TP sharding and stacked-param packing automatically.
|
||||
"""
|
||||
from axolotl.utils.weight_serde import decode_from_ipc
|
||||
|
||||
name, weight = decode_from_ipc(kwargs)
|
||||
model = self.model_runner.model
|
||||
model.load_weights(weights=[(name, weight)])
|
||||
|
||||
def http_load_weights_batch(self, params: list[dict]):
|
||||
"""Load multiple weights in a single IPC call.
|
||||
|
||||
Uses vLLM's ``load_weights`` which handles TP sharding automatically.
|
||||
"""
|
||||
from axolotl.utils.weight_serde import decode_from_ipc
|
||||
|
||||
model = self.model_runner.model
|
||||
weights = [decode_from_ipc(p) for p in params]
|
||||
model.load_weights(weights=weights)
|
||||
|
||||
@@ -138,7 +138,11 @@ def setup_reference_model(
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
reference_model: bool = True
|
||||
if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
|
||||
trl_cfg = getattr(cfg, "trl", None)
|
||||
if (
|
||||
cfg.rl in {RLType.GRPO, RLType.EBFT}
|
||||
and getattr(trl_cfg, "beta", 0) == 0
|
||||
):
|
||||
reference_model = False
|
||||
# load the model again for model_ref/baseline
|
||||
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
|
||||
@@ -206,7 +210,7 @@ def execute_training(
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
gather_outputs=cfg.rl is RLType.GRPO,
|
||||
gather_outputs=cfg.rl in {RLType.GRPO, RLType.EBFT},
|
||||
device_mesh=trainer.accelerator.torch_device_mesh,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -691,8 +691,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
||||
].append(pred_step_text)
|
||||
row_index += 1
|
||||
if logger == "wandb":
|
||||
# type: ignore[attr-defined]
|
||||
wandb.run.log(
|
||||
wandb.run.log( # type: ignore[attr-defined]
|
||||
{
|
||||
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
|
||||
table_data
|
||||
@@ -748,12 +747,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
artifact = wandb.Artifact(
|
||||
f"config-{wandb.run.id}", type="axolotl-config"
|
||||
artifact = wandb.Artifact( # type: ignore[attr-defined]
|
||||
f"config-{wandb.run.id}", # type: ignore[attr-defined]
|
||||
type="axolotl-config",
|
||||
)
|
||||
artifact.add_file(temp_file.name)
|
||||
wandb.log_artifact(artifact)
|
||||
wandb.save(temp_file.name)
|
||||
wandb.log_artifact(artifact) # type: ignore[attr-defined]
|
||||
wandb.save(temp_file.name) # type: ignore[attr-defined]
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the WandB run under files."
|
||||
)
|
||||
@@ -779,12 +779,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
temp_ct_file.write(str(chat_tpl))
|
||||
temp_ct_file.flush()
|
||||
|
||||
artifact = wandb.Artifact(
|
||||
f"chat-template-{wandb.run.id}", type="jinja-template"
|
||||
artifact = wandb.Artifact( # type: ignore[attr-defined]
|
||||
f"chat-template-{wandb.run.id}", # type: ignore[attr-defined]
|
||||
type="jinja-template",
|
||||
)
|
||||
artifact.add_file(temp_ct_file.name)
|
||||
wandb.log_artifact(artifact)
|
||||
wandb.save(temp_ct_file.name)
|
||||
wandb.log_artifact(artifact) # type: ignore[attr-defined]
|
||||
wandb.save(temp_ct_file.name) # type: ignore[attr-defined]
|
||||
LOG.info(
|
||||
"The chat_template_jinja has been saved to the WandB run under files."
|
||||
)
|
||||
@@ -810,13 +811,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
else:
|
||||
skip_upload = True
|
||||
if not skip_upload:
|
||||
artifact = wandb.Artifact(
|
||||
f"deepspeed-config-{wandb.run.id}",
|
||||
artifact = wandb.Artifact( # type: ignore[attr-defined]
|
||||
f"deepspeed-config-{wandb.run.id}", # type: ignore[attr-defined]
|
||||
type="deepspeed-config",
|
||||
)
|
||||
artifact.add_file(temp_file.name)
|
||||
wandb.log_artifact(artifact)
|
||||
wandb.save(temp_file.name)
|
||||
wandb.log_artifact(artifact) # type: ignore[attr-defined]
|
||||
wandb.save(temp_file.name) # type: ignore[attr-defined]
|
||||
LOG.info(
|
||||
"The DeepSpeed config has been saved to the WandB run under files."
|
||||
)
|
||||
|
||||
@@ -28,36 +28,36 @@ class SFTGenerationCallback(TrainerCallback):
|
||||
if not getattr(cfg, "generate_samples", False):
|
||||
return
|
||||
|
||||
dataloader = None
|
||||
try:
|
||||
if getattr(self.trainer, "eval_dataset", None) is not None:
|
||||
dataloader = self.trainer.get_eval_dataloader()
|
||||
LOG.info(
|
||||
f"Using eval dataloader for generation at step {state.global_step}"
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning(f"Could not get eval dataloader: {e}")
|
||||
dataloader = None
|
||||
|
||||
if dataloader is None:
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
dataloader = None
|
||||
try:
|
||||
if getattr(self.trainer, "eval_dataset", None) is not None:
|
||||
dataloader = self.trainer.get_eval_dataloader()
|
||||
LOG.info(
|
||||
f"Using train dataloader for generation at step {state.global_step}"
|
||||
f"Using eval dataloader for generation at step {state.global_step}"
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning(f"Could not get eval dataloader: {e}")
|
||||
dataloader = None
|
||||
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
dataloader=dataloader,
|
||||
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
|
||||
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
|
||||
temperature=getattr(cfg, "generation_temperature", 0.7),
|
||||
top_p=getattr(cfg, "generation_top_p", None),
|
||||
top_k=getattr(cfg, "generation_top_k", None),
|
||||
do_sample=getattr(cfg, "generation_do_sample", True),
|
||||
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
|
||||
if dataloader is None:
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
LOG.info(
|
||||
f"Using train dataloader for generation at step {state.global_step}"
|
||||
)
|
||||
self._log_samples(samples, state.global_step)
|
||||
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
dataloader=dataloader,
|
||||
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
|
||||
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
|
||||
temperature=getattr(cfg, "generation_temperature", 0.7),
|
||||
top_p=getattr(cfg, "generation_top_p", None),
|
||||
top_k=getattr(cfg, "generation_top_k", None),
|
||||
do_sample=getattr(cfg, "generation_do_sample", True),
|
||||
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
|
||||
)
|
||||
self._log_samples(samples, state.global_step)
|
||||
|
||||
def _log_samples(self, samples: list, step: int):
|
||||
"""Log generated samples to console and W&B."""
|
||||
@@ -71,10 +71,10 @@ class SFTGenerationCallback(TrainerCallback):
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
if wandb.run is not None: # type: ignore[attr-defined]
|
||||
wandb.log( # type: ignore[attr-defined]
|
||||
{
|
||||
f"samples/sample_{i + 1}": wandb.Html(
|
||||
f"samples/sample_{i + 1}": wandb.Html( # type: ignore[attr-defined]
|
||||
f"<pre>{wandb_text}</pre>"
|
||||
)
|
||||
},
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.ebft import load as load_ebft
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
@@ -173,7 +174,7 @@ def _drop_long_sequences(
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
if rl in {RLType.GRPO, RLType.GDPO}:
|
||||
if rl in {RLType.GRPO, RLType.GDPO, RLType.EBFT}:
|
||||
return True
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
@@ -209,12 +210,30 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
|
||||
elif cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
|
||||
elif cfg.rl is RLType.EBFT:
|
||||
ds_transform_fn = load_ebft(_type, cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
|
||||
|
||||
map_kwargs: dict[str, Any] = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
# Handle remove_columns: "__all__" removes all original columns,
|
||||
# or filter a list to only columns that exist in the dataset
|
||||
if "remove_columns" in map_kwargs:
|
||||
ds_columns = (
|
||||
dataset.column_names
|
||||
if isinstance(dataset, Dataset)
|
||||
else dataset[split].column_names
|
||||
if isinstance(dataset, DatasetDict)
|
||||
else []
|
||||
)
|
||||
if map_kwargs["remove_columns"] == "__all__":
|
||||
map_kwargs["remove_columns"] = list(ds_columns)
|
||||
else:
|
||||
map_kwargs["remove_columns"] = [
|
||||
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
||||
]
|
||||
split_datasets[i] = _map_dataset(
|
||||
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
|
||||
@@ -55,6 +55,119 @@ from axolotl.utils.schemas.vllm import VllmConfig
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class EBFTConfig(BaseModel):
|
||||
"""Configuration for Energy-Based Fine-Tuning (EBFT)"""
|
||||
|
||||
feature_layers: list[float] = Field(
|
||||
default=[0.25, 0.5, 0.75],
|
||||
json_schema_extra={
|
||||
"description": "Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])"
|
||||
},
|
||||
)
|
||||
embed_method: Literal["last_token", "mean_pooling", "completion_mean", "concat"] = (
|
||||
Field(
|
||||
default="last_token",
|
||||
json_schema_extra={
|
||||
"description": "Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'"
|
||||
},
|
||||
)
|
||||
)
|
||||
use_whitening: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
|
||||
)
|
||||
alignment_coef: float = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Coefficient for alignment reward (cosine similarity with ground truth)"
|
||||
},
|
||||
)
|
||||
diversity_coef: float = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Coefficient for diversity penalty (pairwise similarity between samples)"
|
||||
},
|
||||
)
|
||||
ce_coef: float = Field(
|
||||
default=0.0,
|
||||
json_schema_extra={
|
||||
"description": "Cross-entropy loss coefficient on ground-truth tokens"
|
||||
},
|
||||
)
|
||||
adaptive_max_tokens: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Set per-batch max_tokens based on ground-truth length"
|
||||
},
|
||||
)
|
||||
gt_length_multiplier: float = Field(
|
||||
default=1.5,
|
||||
ge=0.1,
|
||||
json_schema_extra={
|
||||
"description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
|
||||
},
|
||||
)
|
||||
|
||||
# Strided mode fields (for unstructured text)
|
||||
mode: Literal["structured", "strided"] = Field(
|
||||
default="structured",
|
||||
json_schema_extra={
|
||||
"description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
|
||||
},
|
||||
)
|
||||
stride: int = Field(
|
||||
default=8,
|
||||
ge=1,
|
||||
json_schema_extra={"description": "Stride between anchor points (tokens)"},
|
||||
)
|
||||
context_length: int = Field(
|
||||
default=8,
|
||||
ge=1,
|
||||
json_schema_extra={"description": "Context window size per block"},
|
||||
)
|
||||
generate_max_len: int = Field(
|
||||
default=8,
|
||||
ge=1,
|
||||
json_schema_extra={"description": "Tokens to generate per block"},
|
||||
)
|
||||
n_samples_per_prompt: int = Field(
|
||||
default=4,
|
||||
ge=1,
|
||||
json_schema_extra={"description": "Independent rollouts per document"},
|
||||
)
|
||||
temperature: float = Field(
|
||||
default=0.6,
|
||||
ge=0.0,
|
||||
json_schema_extra={
|
||||
"description": "Sampling temperature for strided generation"
|
||||
},
|
||||
)
|
||||
top_p: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
json_schema_extra={"description": "Top-p nucleus sampling threshold"},
|
||||
)
|
||||
rl_coef: float = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={"description": "RL policy gradient loss coefficient"},
|
||||
)
|
||||
advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
|
||||
default="rloo",
|
||||
json_schema_extra={
|
||||
"description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
|
||||
},
|
||||
)
|
||||
min_completion_prefix: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"description": "Minimum tokens into completion before placing anchors. "
|
||||
"Skips anchors too close to the prompt boundary where features are dominated by prompt context."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
ModelOutputConfig,
|
||||
@@ -131,7 +244,7 @@ class AxolotlInputConfig(
|
||||
rl: RLType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
|
||||
"description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'"
|
||||
},
|
||||
)
|
||||
trl: TRLConfig | None = Field(
|
||||
@@ -140,6 +253,12 @@ class AxolotlInputConfig(
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(),
|
||||
)
|
||||
ebft: EBFTConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
|
||||
},
|
||||
)
|
||||
qat: QATConfig | None = None
|
||||
quantization: PTQConfig | None = None
|
||||
reward_model: bool | None = Field(
|
||||
|
||||
@@ -35,6 +35,7 @@ class RLType(str, Enum):
|
||||
ORPO = "orpo"
|
||||
KTO = "kto"
|
||||
SIMPO = "simpo"
|
||||
EBFT = "ebft"
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Pydantic models for TRL trainer configuration"""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -133,6 +133,20 @@ class TRLConfig(BaseModel):
|
||||
"description": "Penalty for tokens that appear in prompt and generated text."
|
||||
},
|
||||
)
|
||||
generation_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Additional generation parameters passed to vLLM SamplingParams. "
|
||||
"Useful for stop_token_ids, seed, frequency_penalty, etc."
|
||||
},
|
||||
)
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Additional kwargs for the chat template. "
|
||||
"E.g., {enable_thinking: false} for Qwen3.5 models."
|
||||
},
|
||||
)
|
||||
num_iterations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -1482,6 +1482,124 @@ class DistributedValidationMixin:
|
||||
return self
|
||||
|
||||
|
||||
class EBFTValidationMixin:
|
||||
"""Validation for EBFT (Energy-Based Fine-Tuning) configuration."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_config_required(cls, data):
|
||||
"""rl: ebft requires an ebft config section."""
|
||||
if data.get("rl") == "ebft" and not data.get("ebft"):
|
||||
raise ValueError(
|
||||
"`ebft` config section is required when `rl: ebft` is set. "
|
||||
"Add an `ebft:` section with at least `mode: structured` or `mode: strided`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_torch_compile(cls, data):
|
||||
"""torch_compile + flex_attention + gradient_checkpointing causes dynamo recompiles
|
||||
and CheckpointErrors. The flex_attention kernel compiles itself internally —
|
||||
whole-model torch.compile is not needed and actively harmful."""
|
||||
if (
|
||||
data.get("rl") == "ebft"
|
||||
and data.get("torch_compile") is True
|
||||
and data.get("ebft", {}).get("mode") == "strided"
|
||||
):
|
||||
if data.get("gradient_checkpointing"):
|
||||
raise ValueError(
|
||||
"EBFT strided mode: `torch_compile: true` with `gradient_checkpointing: true` "
|
||||
"causes CheckpointError (BlockMask metadata mismatch during recomputation). "
|
||||
"Remove `torch_compile` — the flex_attention kernel compiles itself internally."
|
||||
)
|
||||
LOG.warning(
|
||||
"EBFT strided mode: `torch_compile: true` causes dynamo recompiles from "
|
||||
"variable sequence lengths across steps. Consider removing it — "
|
||||
"flex_attention compiles itself internally."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_gradient_checkpointing_reentrant(cls, data):
|
||||
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
|
||||
if (
|
||||
data.get("rl") == "ebft"
|
||||
and data.get("ebft", {}).get("mode") == "strided"
|
||||
and data.get("flex_attention")
|
||||
and data.get("gradient_checkpointing")
|
||||
):
|
||||
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
|
||||
if not gc_kwargs.get("use_reentrant"):
|
||||
LOG.warning(
|
||||
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
|
||||
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
|
||||
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
|
||||
)
|
||||
if data.get("gradient_checkpointing_kwargs") is None:
|
||||
data["gradient_checkpointing_kwargs"] = {}
|
||||
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_activation_offloading(cls, data):
|
||||
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
|
||||
which conflicts with flex_attention's use_reentrant requirement."""
|
||||
if (
|
||||
data.get("rl") == "ebft"
|
||||
and data.get("ebft", {}).get("mode") == "strided"
|
||||
and data.get("activation_offloading") is True
|
||||
and data.get("flex_attention")
|
||||
):
|
||||
raise ValueError(
|
||||
"EBFT strided mode: `activation_offloading: true` is incompatible with "
|
||||
"`flex_attention: true`. Activation offloading replaces gradient checkpointing "
|
||||
"with FSDP-style wrapping that conflicts with flex_attention's reentrant "
|
||||
"checkpoint requirement. Remove `activation_offloading` — the strided trainer "
|
||||
"uses micro-batched forward passes for memory efficiency instead."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_strided_sequence_len(cls, data):
|
||||
"""Warn if sequence_len is too large for single-GPU strided EBFT."""
|
||||
if data.get("rl") != "ebft" or data.get("ebft", {}).get("mode") != "strided":
|
||||
return data
|
||||
ebft = data.get("ebft", {})
|
||||
seq_len = data.get("sequence_len", 512)
|
||||
n_samples = ebft.get("n_samples_per_prompt", 4)
|
||||
gen_len = ebft.get("generate_max_len", 8)
|
||||
stride = ebft.get("stride", 8)
|
||||
ctx_len = ebft.get("context_length", 8)
|
||||
max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
|
||||
full_seq = seq_len + max_blocks * gen_len
|
||||
# Rough estimate: 8.7 GB per sample at S=3900 for 1B model
|
||||
if full_seq * n_samples > 20000:
|
||||
LOG.warning(
|
||||
f"EBFT strided: full_seq_len={full_seq} * n_samples={n_samples} = "
|
||||
f"{full_seq * n_samples} token-samples per step. This may require >24GB VRAM "
|
||||
f"for a 1B+ model. Consider reducing sequence_len, n_samples_per_prompt, or stride."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_ebft_strided_dataset_split(cls, data):
|
||||
"""Warn about the common `train_on_split` mistake (silently ignored by schema)."""
|
||||
datasets = data.get("datasets", [])
|
||||
for ds in datasets or []:
|
||||
if isinstance(ds, dict) and ds.get("train_on_split"):
|
||||
LOG.warning(
|
||||
f"Dataset has `train_on_split: {ds['train_on_split']}` — this field "
|
||||
f"is not recognized and will be silently ignored. "
|
||||
f"Use `split: {ds['train_on_split']}` instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class GRPOVllmValidationMixin:
|
||||
"""Validation mixin for vllm when using GRPO."""
|
||||
|
||||
@@ -1507,6 +1625,7 @@ class ValidationMixin(
|
||||
PretrainingValidationMixin,
|
||||
ModelCompatibilityValidationMixin,
|
||||
ComplexValidationMixin,
|
||||
EBFTValidationMixin,
|
||||
GRPOVllmValidationMixin,
|
||||
):
|
||||
"""Full validation mixin for Axolotl configuration."""
|
||||
|
||||
@@ -57,6 +57,13 @@ class VllmConfig(BaseModel):
|
||||
default=None,
|
||||
json_schema_extra={"description": "Reasoning parser for VLLM"},
|
||||
)
|
||||
enforce_eager: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Disable CUDA graph capture in vLLM. Required for models with "
|
||||
"causal_conv1d (e.g., Qwen3.5 hybrid linear attention)."
|
||||
},
|
||||
)
|
||||
serve_module: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
94
src/axolotl/utils/weight_serde.py
Normal file
94
src/axolotl/utils/weight_serde.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Serialize / deserialize tensors for HTTP and IPC weight sync.
|
||||
|
||||
NumPy doesn't support bfloat16, so bf16 tensors are cast to fp16 on the wire
|
||||
and reconstructed at the destination. All encode/decode helpers live here so
|
||||
the logic isn't duplicated across trl_vllm.py, vllm_serve_lora.py, and
|
||||
vllm_worker_ext.py.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def encode_for_http(name: str, weight: torch.Tensor) -> dict:
|
||||
"""Encode a named parameter for JSON transport over HTTP.
|
||||
|
||||
Returns a dict with keys: name, dtype (original), shape, data (base64).
|
||||
bf16 tensors are sent as fp16 bytes; the original dtype is preserved in
|
||||
the ``dtype`` field so the receiver can cast back.
|
||||
"""
|
||||
w_cpu = weight.contiguous().cpu()
|
||||
orig_dtype = str(weight.dtype)
|
||||
if w_cpu.dtype == torch.bfloat16:
|
||||
w_cpu = w_cpu.half()
|
||||
raw = w_cpu.numpy().tobytes()
|
||||
return {
|
||||
"name": name,
|
||||
"dtype": orig_dtype,
|
||||
"shape": list(weight.shape),
|
||||
"data": base64.b64encode(raw).decode("ascii"),
|
||||
}
|
||||
|
||||
|
||||
def decode_from_http(entry: dict) -> tuple[str, torch.Tensor]:
|
||||
"""Decode an HTTP-encoded weight entry back to a named tensor.
|
||||
|
||||
Infers wire dtype from byte count (bf16 arrives as fp16) and casts to the
|
||||
original dtype stored in ``entry["dtype"]``.
|
||||
"""
|
||||
target_dtype = getattr(torch, entry["dtype"].split(".")[-1])
|
||||
shape = tuple(entry["shape"])
|
||||
raw = base64.b64decode(entry["data"])
|
||||
|
||||
n_elements = 1
|
||||
for s in shape:
|
||||
n_elements *= s
|
||||
wire_bytes_per_elem = len(raw) // max(n_elements, 1)
|
||||
if wire_bytes_per_elem == 2:
|
||||
wire_dtype = torch.float16
|
||||
elif wire_bytes_per_elem == 4:
|
||||
wire_dtype = torch.float32
|
||||
else:
|
||||
wire_dtype = target_dtype
|
||||
|
||||
weight = torch.frombuffer(bytearray(raw), dtype=wire_dtype).reshape(shape)
|
||||
if wire_dtype != target_dtype:
|
||||
weight = weight.to(target_dtype)
|
||||
return entry["name"], weight
|
||||
|
||||
|
||||
def encode_for_ipc(name: str, weight: torch.Tensor) -> dict:
|
||||
"""Encode a tensor for vLLM's multiproc IPC (raw bytes, no base64).
|
||||
|
||||
Returns a dict with keys: name, data (bytes), dtype (wire), target_dtype
|
||||
(original), shape. bf16 tensors are serialized as fp16.
|
||||
"""
|
||||
w = weight.contiguous()
|
||||
target_dtype = str(w.dtype).split(".")[-1]
|
||||
if w.dtype == torch.bfloat16:
|
||||
w = w.half()
|
||||
wire_dtype = str(w.dtype).split(".")[-1]
|
||||
return {
|
||||
"name": name,
|
||||
"data": w.numpy().tobytes(),
|
||||
"dtype": wire_dtype,
|
||||
"target_dtype": target_dtype,
|
||||
"shape": list(weight.shape),
|
||||
}
|
||||
|
||||
|
||||
def decode_from_ipc(entry: dict) -> tuple[str, torch.Tensor]:
|
||||
"""Decode an IPC-encoded weight entry back to a named tensor.
|
||||
|
||||
Handles optional ``target_dtype`` for backward compatibility with older
|
||||
serve code that may not include it.
|
||||
"""
|
||||
wire_dtype = getattr(torch, entry["dtype"])
|
||||
weight = torch.frombuffer(bytearray(entry["data"]), dtype=wire_dtype).reshape(
|
||||
entry["shape"]
|
||||
)
|
||||
target_dtype = entry.get("target_dtype")
|
||||
if target_dtype and target_dtype != entry["dtype"]:
|
||||
weight = weight.to(getattr(torch, target_dtype))
|
||||
return entry["name"], weight
|
||||
Reference in New Issue
Block a user