Compare commits

...

2 Commits

Author SHA1 Message Date
Dan Saunders
64f349b7bb diffusion alt: custom loss impl 2025-08-18 20:50:34 +00:00
Dan Saunders
260ebe4c93 diffusion alt: custom loss impl 2025-08-18 20:50:20 +00:00
5 changed files with 575 additions and 471 deletions

View File

@@ -0,0 +1,115 @@
"""Diffusion LM loss function for integration with transformers LOSS_MAPPING."""
from typing import Optional
import torch
import torch.nn.functional as F
def ForDiffusionLMLoss(
logits: torch.Tensor,
labels: torch.Tensor,
vocab_size: int,
config: Optional[dict] = None,
inputs: Optional[dict] = None,
model: Optional[torch.nn.Module] = None,
**kwargs,
) -> torch.Tensor:
"""
Diffusion Language Modeling loss function.
This function computes cross-entropy loss only on masked tokens using
diffusion info stored by the model patch during forward pass.
Args:
logits: Model predictions [batch_size, seq_len, vocab_size]
labels: Ground truth tokens [batch_size, seq_len]
vocab_size: Size of vocabulary
config: Model configuration (contains diffusion parameters)
inputs: Input batch dictionary (contains input_ids, attention_mask)
model: The model instance (to access stored diffusion info)
**kwargs: Additional arguments
Returns:
loss: Computed diffusion loss
"""
# Get diffusion info stored by model patch
if model is None or not hasattr(model, "_diffusion_info"):
# Fallback to regular causal LM loss if no diffusion info
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
return loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
diffusion_info = model._diffusion_info
original_input_ids = diffusion_info["original_input_ids"]
masked_indices = diffusion_info["masked_indices"]
p_mask = diffusion_info["p_mask"]
# Get diffusion config parameters
diffusion_config = getattr(config, "diffusion_config", {})
importance_weighting = diffusion_config.get("importance_weighting", True)
# Check if we have any masked tokens
if not masked_indices.any():
return torch.tensor(0.0, device=logits.device, requires_grad=True)
# Get predictions and targets for masked positions only
masked_logits = logits[masked_indices]
masked_targets = original_input_ids[masked_indices] # Original unmasked tokens
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if importance_weighting:
# Apply importance weighting: 1 / p_mask
masked_p_mask = p_mask.expand_as(masked_indices)[masked_indices]
weighted_loss = token_loss / masked_p_mask
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float()
# Group losses by batch sample
batch_indices = torch.arange(
original_input_ids.shape[0], device=original_input_ids.device
)
batch_indices = batch_indices.unsqueeze(1).expand_as(masked_indices)
masked_batch_indices = batch_indices[masked_indices]
# Sum losses per sample and normalize by answer length
loss_per_sample = torch.zeros(
original_input_ids.shape[0], device=original_input_ids.device
)
for i in range(original_input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.any():
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / max(answer_lengths[i], 1)
loss = loss_per_sample.mean()
else:
# For completion data: simple average
loss = weighted_loss.mean()
else:
# No importance weighting
loss = token_loss.mean()
return loss
def register_diffusion_loss():
"""Register the diffusion loss function in transformers LOSS_MAPPING."""
try:
from transformers.loss.loss_utils import LOSS_MAPPING
LOSS_MAPPING["ForDiffusionLM"] = ForDiffusionLMLoss
return True
except ImportError:
# Fallback for older transformers versions
return False

View File

@@ -0,0 +1,149 @@
"""Model patches for diffusion training."""
import torch
def patch_model_for_bidirectional_attention(model):
"""
Patch model to handle diffusion training with forward process and bidirectional
attention.
This monkey-patches the model's forward method to:
- Apply forward diffusion process (masking) during training
- Use bidirectional attention masks
- Store info for loss computation
"""
original_forward = model.forward
def diffusion_forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
**kwargs,
):
# Check if this is diffusion training
if (
hasattr(self.config, "loss_type")
and self.config.loss_type == "ForDiffusionLM"
and self.training
):
# Store original input_ids for loss computation
original_input_ids = input_ids.clone()
# Apply forward diffusion process (masking)
diffusion_config = getattr(self.config, "diffusion_config", {})
noisy_input_ids, masked_indices, p_mask = _forward_process(
input_ids, attention_mask, labels, diffusion_config
)
# Use noisy input for model forward
input_ids = noisy_input_ids
# Convert attention mask to bidirectional
if attention_mask is not None:
attention_mask = _create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Store diffusion info in the model for loss computation
self._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
return original_forward(
input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
)
# Replace the forward method
model.forward = diffusion_forward.__get__(model, model.__class__)
def _create_bidirectional_attention_mask(
input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Create bidirectional attention mask from 2D attention mask.
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: 2D attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
"""
batch_size, seq_len = input_ids.shape
# Simple bidirectional mask - all tokens can attend to all valid tokens
# Expand 2D mask to 4D: [batch_size, seq_len] -> [batch_size, 1, seq_len, seq_len]
bidirectional_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S]
bidirectional_mask = bidirectional_mask.expand(batch_size, 1, seq_len, seq_len)
# Apply row-wise masking (padded tokens can't attend to anything)
row_mask = attention_mask.unsqueeze(1).unsqueeze(3) # [B, 1, S, 1]
bidirectional_mask = bidirectional_mask & row_mask
return bidirectional_mask
def _forward_process(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
diffusion_config: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Apply forward diffusion process (random masking).
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
labels: Labels for SFT training [batch_size, seq_len]
diffusion_config: Diffusion configuration dict
Returns:
noisy_input_ids: Input with masked tokens
masked_indices: Boolean mask of which tokens were masked
p_mask: Masking probabilities used
"""
if diffusion_config is None:
diffusion_config = {}
batch_size, seq_len = input_ids.shape
device = input_ids.device
eps = diffusion_config.get("eps", 1e-3)
mask_token_id = diffusion_config.get("mask_token_id", 128002)
# Sample random timesteps for each sample
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask.unsqueeze(1).expand(-1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens
if attention_mask is not None:
p_mask = p_mask * attention_mask.float()
# Create random mask based on p_mask
random_values = torch.rand_like(p_mask)
masked_indices = random_values < p_mask
# Apply attention mask constraints
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens (where labels != -100)
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create noisy input by replacing masked tokens
noisy_input_ids = input_ids.clone()
noisy_input_ids[masked_indices] = mask_token_id
return noisy_input_ids, masked_indices, p_mask

View File

@@ -7,7 +7,10 @@ from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .trainer import DiffusionTrainer from .args import DiffusionArgs
from .callbacks import DiffusionGenerationCallback
from .loss import register_diffusion_loss
from .model_patch import patch_model_for_bidirectional_attention
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -24,18 +27,70 @@ class DiffusionPlugin(BasePlugin):
super().__init__() super().__init__()
self.cfg = None self.cfg = None
if register_diffusion_loss():
LOG.info("Registered ForDiffusionLM loss function")
else:
LOG.warning(
"Failed to register diffusion loss - older transformers version"
)
def get_input_args(self) -> str: def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Perform actions after model is loaded.""" """Configure model for diffusion training after loading."""
self.cfg = cfg self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: # Set loss type for diffusion training
"""Return custom trainer class for diffusion training.""" if hasattr(model, "config"):
return DiffusionTrainer model.config.loss_type = "ForDiffusionLM"
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): # Store diffusion config in model config
model.config.diffusion_config = {
"eps": getattr(cfg, "eps", 1e-3),
"importance_weighting": getattr(cfg, "importance_weighting", True),
"mask_token_id": getattr(cfg, "mask_token_id", 128002),
}
LOG.info("Configured model for diffusion training with ForDiffusionLM loss")
# Patch model for bidirectional attention during training
patch_model_for_bidirectional_attention(model)
LOG.info("Applied bidirectional attention patch to model")
return model
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Configure trainer after creation.""" """Configure trainer after creation."""
trainer.set_config(cfg) # Create diffusion config from cfg
diffusion_config = DiffusionArgs(
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.1),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 0.9),
num_diffusion_steps=getattr(cfg, "num_diffusion_steps", 128),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", True),
mask_token_id=getattr(cfg, "mask_token_id", 128002),
generate_samples=getattr(cfg, "generate_samples", True),
generation_interval=getattr(cfg, "generation_interval", 100),
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
generation_steps=getattr(cfg, "generation_steps", 128),
generation_temperature=getattr(cfg, "generation_temperature", 0.0),
generation_max_length=getattr(cfg, "generation_max_length", 100),
)
# Store diffusion config on trainer for callbacks to access
trainer.diffusion_config = diffusion_config
LOG.info("Stored diffusion config on trainer")
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
"""Add diffusion generation callback if enabled."""
if (
hasattr(trainer, "diffusion_config")
and trainer.diffusion_config.generate_samples
):
generation_callback = DiffusionGenerationCallback(trainer)
LOG.info("Added diffusion generation callback")
return [generation_callback]
return []

View File

@@ -1,279 +0,0 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
self._cache_special_token_ids()
if config.generate_samples:
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
@torch.compile
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
@torch.compile
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass
outputs = model(
input_ids=noisy_batch,
attention_mask=bidirectional_mask,
)
logits = outputs.logits
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
metrics = {
"loss": loss.item(),
"accuracy": accuracy.item(),
"mask_ratio": masked_indices.float().mean().item(),
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
"avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(),
}
if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs

View File

@@ -2,111 +2,180 @@
# pylint: disable=redefined-outer-name,protected-access # pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
from axolotl.integrations.diffusion.trainer import DiffusionTrainer from axolotl.integrations.diffusion.args import DiffusionArgs
from axolotl.utils.dict import DictDefault from axolotl.integrations.diffusion.loss import (
ForDiffusionLMLoss,
register_diffusion_loss,
@pytest.fixture )
def mock_tokenizer(): from axolotl.integrations.diffusion.model_patch import (
"""Create a mock tokenizer.""" _create_bidirectional_attention_mask,
tokenizer = Mock() _forward_process,
tokenizer.bos_token_id = 1 patch_model_for_bidirectional_attention,
tokenizer.eos_token_id = 2 )
tokenizer.pad_token_id = 0 from axolotl.integrations.diffusion.plugin import DiffusionPlugin
return tokenizer
@pytest.fixture @pytest.fixture
def diffusion_config(): def diffusion_config():
"""Create a diffusion config.""" """Create a diffusion config."""
return DictDefault( return DiffusionArgs(
{ eps=1e-3,
"mask_token_id": 32000, importance_weighting=False,
"eps": 1e-3, mask_token_id=32000,
"importance_weighting": False, generate_samples=False,
"sample_packing": False,
}
) )
@pytest.fixture @pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config): def mock_model():
"""Create a diffusion trainer instance for testing methods directly.""" """Create a mock model."""
# Create a minimal trainer instance just for testing methods model = Mock()
trainer = object.__new__(DiffusionTrainer) # Bypass __init__ model.config = Mock()
trainer.config = diffusion_config model.config.loss_type = "ForDiffusionLM"
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos model.config.diffusion_config = {
trainer.processing_class = mock_tokenizer "eps": 1e-3,
trainer.store_metrics = Mock() # Mock metrics storage "importance_weighting": False,
return trainer "mask_token_id": 32000,
}
model.training = True
return model
class TestDiffusionTrainer: class TestDiffusionLoss:
"""Test the DiffusionTrainer class.""" """Test the ForDiffusionLMLoss function."""
def test_forward_process_basic(self, diffusion_trainer_instance): def test_loss_with_diffusion_info(self, mock_model):
"""Test basic forward process without labels.""" """Test loss computation with stored diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Mock stored diffusion info
original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
masked_indices = torch.tensor(
[[False, True, True, False, False]], dtype=torch.bool
)
p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
# Mock logits
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert loss.item() >= 0
def test_loss_fallback_without_diffusion_info(self, mock_model):
"""Test fallback to causal LM loss when no diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Remove diffusion info to trigger fallback
if hasattr(mock_model, "_diffusion_info"):
delattr(mock_model, "_diffusion_info")
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_loss_no_masked_tokens(self, mock_model):
"""Test loss when no tokens are masked."""
batch_size, seq_len, vocab_size = 1, 3, 1000
# No masked tokens
original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert loss.item() == 0.0
class TestModelPatch:
"""Test the model patching functionality."""
def test_forward_process_basic(self):
"""Test basic forward process."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = ( noisy_input_ids, masked_indices, p_mask = _forward_process(
diffusion_trainer_instance._forward_process(input_ids, eps=0.1) input_ids, diffusion_config=diffusion_config
) )
# Check shapes # Check shapes
assert noisy_batch.shape == input_ids.shape assert noisy_input_ids.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked # Check that mask token is applied where masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) if masked_indices.any():
assert not masked_indices[special_token_positions].any() assert (noisy_input_ids[masked_indices] == 32000).all()
# Check that mask token is applied def test_forward_process_with_labels(self):
mask_token_id = diffusion_trainer_instance._config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer_instance):
"""Test forward process with SFT labels.""" """Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = ( _, masked_indices, _ = _forward_process(
diffusion_trainer_instance._forward_process( input_ids, labels=labels, diffusion_config=diffusion_config
input_ids, labels=labels, eps=0.1
)
) )
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100) # Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100 non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any() assert not masked_indices[non_answer_mask].any()
# p_mask should be the same for all positions (sampled timestep), def test_forward_process_with_attention_mask(self):
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
"""Test forward process with attention mask.""" """Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process( _, masked_indices, p_mask = _forward_process(
input_ids, attention_mask=attention_mask, eps=0.1 input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config
) )
# Check that padding tokens are not masked # Check that padding tokens are not masked
@@ -114,158 +183,153 @@ class TestDiffusionTrainer:
assert not masked_indices[padding_positions].any() assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all() assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): def test_bidirectional_attention_mask(self):
"""Test bidirectional attention mask without sample packing.""" """Test bidirectional attention mask creation."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask( mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
input_ids
)
# Should be all-to-all attention # Should be all-to-all attention
expected_shape = (1, 1, 4, 4) expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape assert mask.shape == expected_shape
assert mask.all() assert mask.all()
def test_bidirectional_attention_mask_with_packing( def test_bidirectional_attention_mask_with_padding(self):
self, diffusion_trainer_instance """Test bidirectional attention mask with padding."""
): input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
"""Test bidirectional attention mask with sample packing.""" attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
diffusion_trainer_instance._config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask( mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
input_ids, attention_mask
)
# Check that tokens within same sample can attend to each other # Padding positions should not attend or be attended to
# but not across samples assert not mask[0, 0, 3, :].any() # Padding can't attend to anything
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other assert not mask[0, 0, :, 3].any() # Nothing can attend to padding
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_compute_loss_basic(self, diffusion_trainer_instance): def test_patch_model_for_bidirectional_attention(self):
"""Test basic loss computation.""" """Test that model patching works."""
# Mock model that returns logits
mock_model = Mock() mock_model = Mock()
mock_outputs = Mock() mock_model.config = Mock()
vocab_size = 1000 mock_model.config.loss_type = "ForDiffusionLM"
seq_len = 5 mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) original_forward = Mock()
mock_model.forward = original_forward
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( # Patch the model
mock_model, input_ids patch_model_for_bidirectional_attention(mock_model)
)
# Check that loss is computed # Check that forward method was replaced
assert isinstance(loss, torch.Tensor) assert mock_model.forward != original_forward
assert loss.requires_grad
assert outputs == mock_outputs
# Check that metrics were stored
diffusion_trainer_instance.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer_instance): class TestDiffusionPlugin:
"""Test loss computation with SFT labels.""" """Test the DiffusionPlugin."""
# Mock model
def test_plugin_registers_loss_function(self):
"""Test that plugin registers diffusion loss function."""
with patch(
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
return_value=True,
) as mock_register:
plugin = DiffusionPlugin()
mock_register.assert_called_once()
def test_post_model_load_configuration(self):
"""Test that post_model_load configures model correctly."""
plugin = DiffusionPlugin()
# Mock model and config
mock_model = Mock() mock_model = Mock()
mock_outputs = Mock() mock_model.config = Mock()
vocab_size = 1000 mock_cfg = Mock()
seq_len = 5 mock_cfg.eps = 1e-3
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) mock_cfg.importance_weighting = True
mock_model.return_value = mock_outputs mock_cfg.mask_token_id = 32000
mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) with patch(
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) "axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
) as mock_patch:
result = plugin.post_model_load(mock_cfg, mock_model)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss( # Check model configuration
mock_model, input_ids, labels=labels assert mock_model.config.loss_type == "ForDiffusionLM"
) assert mock_model.config.diffusion_config is not None
assert mock_model.config.diffusion_config["eps"] == 1e-3
# Check that loss is computed # Check model was patched
assert isinstance(loss, torch.Tensor) mock_patch.assert_called_once_with(mock_model)
assert loss.requires_grad
# Check that SFT metrics were added # Should return the model
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] assert result == mock_model
assert "answer_ratio" in call_args
assert "avg_answer_length" in call_args
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): def test_post_trainer_create_stores_config(self, diffusion_config):
"""Test loss computation when no tokens are masked.""" """Test that post_trainer_create stores config on trainer."""
# Mock model plugin = DiffusionPlugin()
mock_model = Mock() mock_trainer = Mock()
mock_outputs = Mock() mock_cfg = Mock()
vocab_size = 1000
seq_len = 3
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
mock_model.return_value = mock_outputs
mock_model.training = True
# Only special tokens (which won't be masked) # Set config attributes
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) for attr, value in diffusion_config.model_dump().items():
setattr(mock_cfg, attr, value)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss( plugin.post_trainer_create(mock_cfg, mock_trainer)
mock_model, input_ids
)
# Loss should be zero when no tokens are masked # Check that diffusion config was stored on trainer
assert loss.item() == 0.0 assert hasattr(mock_trainer, "diffusion_config")
assert loss.requires_grad assert mock_trainer.diffusion_config.eps == diffusion_config.eps
def test_cache_special_token_ids(self, diffusion_trainer_instance): def test_add_callbacks_post_trainer_with_generation_enabled(self):
"""Test caching of special token IDs.""" """Test callback addition when generation is enabled."""
# Should cache BOS, EOS, PAD tokens plugin = DiffusionPlugin()
expected_tokens = {0, 1, 2} # pad, bos, eos mock_trainer = Mock()
assert diffusion_trainer_instance._special_token_ids == expected_tokens mock_cfg = Mock()
def test_cache_special_token_ids_no_tokenizer(self): # Mock trainer with diffusion config that has generation enabled
"""Test caching when no tokenizer is available.""" mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.processing_class = None
trainer._cache_special_token_ids()
assert trainer._special_token_ids == set() with patch(
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
) as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
def test_main_compute_loss_interface(self, diffusion_trainer_instance): # Should return one callback
"""Test the main compute_loss interface.""" assert len(callbacks) == 1
# Mock model mock_callback_class.assert_called_once_with(mock_trainer)
mock_model = Mock()
mock_outputs = Mock()
mock_outputs.logits = torch.randn(1, 5, 1000)
mock_model.return_value = mock_outputs
mock_model.training = True
inputs = { def test_add_callbacks_post_trainer_with_generation_disabled(self):
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), """Test callback addition when generation is disabled."""
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), plugin = DiffusionPlugin()
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), mock_trainer = Mock()
} mock_cfg = Mock()
# Test without return_outputs # Mock trainer with diffusion config that has generation disabled
loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
assert isinstance(loss, torch.Tensor)
# Test with return_outputs callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
loss, outputs = diffusion_trainer_instance.compute_loss(
mock_model, inputs, return_outputs=True
)
assert isinstance(loss, torch.Tensor)
assert outputs == mock_outputs
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): # Should return no callbacks
"""Test that missing input_ids raises ValueError.""" assert len(callbacks) == 0
mock_model = Mock()
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
with pytest.raises(ValueError, match="input_ids is required"):
diffusion_trainer_instance.compute_loss(mock_model, inputs) class TestLossRegistration:
"""Test loss function registration."""
def test_register_diffusion_loss(self):
"""Test that loss function can be registered."""
with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
result = register_diffusion_loss()
assert result is True
assert "ForDiffusionLM" in mock_mapping
assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
def test_register_diffusion_loss_import_error(self):
"""Test fallback when LOSS_MAPPING import fails."""
# Patch the import to raise ImportError
with patch(
"builtins.__import__",
side_effect=ImportError("transformers.loss.loss_utils not found"),
):
result = register_diffusion_loss()
assert result is False