diff --git a/src/axolotl/integrations/diffusion/loss.py b/src/axolotl/integrations/diffusion/loss.py
new file mode 100644
index 000000000..36661535c
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/loss.py
@@ -0,0 +1,115 @@
+"""Diffusion LM loss function for integration with transformers LOSS_MAPPING."""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def ForDiffusionLMLoss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    vocab_size: int,
+    config: Optional[dict] = None,
+    inputs: Optional[dict] = None,
+    model: Optional[torch.nn.Module] = None,
+    **kwargs,
+) -> torch.Tensor:
+    """
+    Diffusion Language Modeling loss function.
+
+    This function computes cross-entropy loss only on masked tokens, using
+    diffusion info stored by the model patch during the forward pass.
+
+    Args:
+        logits: Model predictions [batch_size, seq_len, vocab_size]
+        labels: Ground truth tokens [batch_size, seq_len]
+        vocab_size: Size of vocabulary
+        config: Model configuration (contains diffusion parameters)
+        inputs: Input batch dictionary (contains input_ids, attention_mask)
+        model: The model instance (to access stored diffusion info)
+        **kwargs: Additional arguments
+
+    Returns:
+        loss: Computed diffusion loss
+    """
+    # Get diffusion info stored by the model patch
+    if model is None or not hasattr(model, "_diffusion_info"):
+        # Fall back to regular causal LM loss if no diffusion info is present
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        loss_fct = torch.nn.CrossEntropyLoss()
+        return loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+
+    diffusion_info = model._diffusion_info
+    original_input_ids = diffusion_info["original_input_ids"]
+    masked_indices = diffusion_info["masked_indices"]
+    p_mask = diffusion_info["p_mask"]
+
+    # Get diffusion config parameters
+    diffusion_config = getattr(config, "diffusion_config", {})
+    importance_weighting = diffusion_config.get("importance_weighting", True)
+
+    # Check if we have any masked tokens
+    if not masked_indices.any():
+        return torch.tensor(0.0, device=logits.device, requires_grad=True)
+
+    # Get predictions and targets for masked positions only
+    masked_logits = logits[masked_indices]
+    masked_targets = original_input_ids[masked_indices]  # Original unmasked tokens
+
+    # Compute cross-entropy loss without reduction
+    token_loss = F.cross_entropy(
+        masked_logits.float(), masked_targets, reduction="none"
+    )
+
+    if importance_weighting:
+        # Apply importance weighting: 1 / p_mask
+        masked_p_mask = p_mask.expand_as(masked_indices)[masked_indices]
+        weighted_loss = token_loss / masked_p_mask
+
+        if labels is not None:
+            # For SFT data: normalize by answer length per sample
+            answer_mask = labels != -100
+            answer_lengths = answer_mask.sum(dim=1).float()
+
+            # Group losses by batch sample
+            batch_indices = torch.arange(
+                original_input_ids.shape[0], device=original_input_ids.device
+            )
+            batch_indices = batch_indices.unsqueeze(1).expand_as(masked_indices)
+            masked_batch_indices = batch_indices[masked_indices]
+
+            # Sum losses per sample and normalize by answer length
+            loss_per_sample = torch.zeros(
+                original_input_ids.shape[0], device=original_input_ids.device
+            )
+            for i in range(original_input_ids.shape[0]):
+                sample_mask = masked_batch_indices == i
+                if sample_mask.any():
+                    sample_loss = weighted_loss[sample_mask].sum()
+                    loss_per_sample[i] = sample_loss / max(answer_lengths[i], 1)
+
+            loss = loss_per_sample.mean()
+        else:
+            # For completion data: simple average
+            loss = weighted_loss.mean()
+    else:
+        # No importance weighting
+        loss = token_loss.mean()
+
+    return loss
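+
+
+# Dispatch sketch (editor's note): once registered, transformers resolves this
+# function from LOSS_MAPPING via `model.config.loss_type == "ForDiffusionLM"`
+# and calls it roughly as
+#
+#     loss = LOSS_MAPPING["ForDiffusionLM"](
+#         logits=outputs.logits,  # [batch_size, seq_len, vocab_size]
+#         labels=labels,          # [batch_size, seq_len], -100 = ignore
+#         vocab_size=config.vocab_size,
+#         **kwargs,
+#     )
+#
+# Note that `model=...` has to reach this call through **kwargs for the
+# diffusion path to activate; otherwise the causal-LM fallback above runs.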
+
+
+def register_diffusion_loss():
+    """Register the diffusion loss function in transformers LOSS_MAPPING."""
+    try:
+        from transformers.loss.loss_utils import LOSS_MAPPING
+
+        LOSS_MAPPING["ForDiffusionLM"] = ForDiffusionLMLoss
+        return True
+    except ImportError:
+        # Fallback for older transformers versions
+        return False
diff --git a/src/axolotl/integrations/diffusion/model_patch.py b/src/axolotl/integrations/diffusion/model_patch.py
new file mode 100644
index 000000000..bd0e4d430
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/model_patch.py
@@ -0,0 +1,149 @@
+"""Model patches for diffusion training."""
+
+import torch
+
+
+def patch_model_for_bidirectional_attention(model):
+    """
+    Patch a model to handle diffusion training with the forward process and
+    bidirectional attention.
+
+    This monkey-patches the model's forward method to:
+    - Apply the forward diffusion process (masking) during training
+    - Use bidirectional attention masks
+    - Store info for loss computation
+    """
+    original_forward = model.forward
+
+    def diffusion_forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None,
+        **kwargs,
+    ):
+        # Check if this is diffusion training
+        if (
+            hasattr(self.config, "loss_type")
+            and self.config.loss_type == "ForDiffusionLM"
+            and self.training
+        ):
+            # Store original input_ids for loss computation
+            original_input_ids = input_ids.clone()
+
+            # Apply the forward diffusion process (masking)
+            diffusion_config = getattr(self.config, "diffusion_config", {})
+            noisy_input_ids, masked_indices, p_mask = _forward_process(
+                input_ids, attention_mask, labels, diffusion_config
+            )
+
+            # Use the noisy input for the model forward
+            input_ids = noisy_input_ids
+
+            # Convert attention mask to bidirectional
+            if attention_mask is not None:
+                attention_mask = _create_bidirectional_attention_mask(
+                    input_ids, attention_mask
+                )
+
+            # Store diffusion info on the model for loss computation
+            self._diffusion_info = {
+                "original_input_ids": original_input_ids,
+                "masked_indices": masked_indices,
+                "p_mask": p_mask,
+            }
+
+        return original_forward(
+            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
+        )
+
+    # Replace the forward method with a bound version of diffusion_forward
+    model.forward = diffusion_forward.__get__(model, model.__class__)
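+
+
+# How the patch is meant to be used (editor's sketch; the model-loading call
+# is illustrative and not part of this module):
+#
+#     model = AutoModelForCausalLM.from_pretrained(...)
+#     model.config.loss_type = "ForDiffusionLM"
+#     model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 128002}
+#     patch_model_for_bidirectional_attention(model)
+#     outputs = model(
+#         input_ids=batch["input_ids"],
+#         attention_mask=batch["attention_mask"],
+#         labels=batch["labels"],
+#     )  # the loss function then reads model._diffusion_info set above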
+
+
+def _create_bidirectional_attention_mask(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor
+) -> torch.Tensor:
+    """
+    Create a bidirectional attention mask from a 2D attention mask.
+
+    Args:
+        input_ids: Input token IDs [batch_size, seq_len]
+        attention_mask: 2D attention mask [batch_size, seq_len]
+
+    Returns:
+        bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
+    """
+    batch_size, seq_len = input_ids.shape
+
+    # Simple bidirectional mask - all tokens can attend to all valid tokens.
+    # Expand the 2D mask to 4D: [batch_size, seq_len] -> [batch_size, 1, seq_len, seq_len]
+    bidirectional_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
+    bidirectional_mask = bidirectional_mask.expand(batch_size, 1, seq_len, seq_len)
+
+    # Apply row-wise masking (padded tokens can't attend to anything)
+    row_mask = attention_mask.unsqueeze(1).unsqueeze(3)  # [B, 1, S, 1]
+    bidirectional_mask = bidirectional_mask & row_mask
+
+    return bidirectional_mask
+
+
+def _forward_process(
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor | None = None,
+    labels: torch.Tensor | None = None,
+    diffusion_config: dict | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Apply the forward diffusion process (random masking).
+
+    Args:
+        input_ids: Input token IDs [batch_size, seq_len]
+        attention_mask: Attention mask [batch_size, seq_len]
+        labels: Labels for SFT training [batch_size, seq_len]
+        diffusion_config: Diffusion configuration dict
+
+    Returns:
+        noisy_input_ids: Input with masked tokens
+        masked_indices: Boolean mask of which tokens were masked
+        p_mask: Masking probabilities used
+    """
+    if diffusion_config is None:
+        diffusion_config = {}
+
+    batch_size, seq_len = input_ids.shape
+    device = input_ids.device
+
+    eps = diffusion_config.get("eps", 1e-3)
+    mask_token_id = diffusion_config.get("mask_token_id", 128002)
+
+    # Sample a random timestep for each sample in the batch
+    t = torch.rand(batch_size, device=device)
+
+    # Calculate the masking probability with epsilon
+    p_mask = (1 - eps) * t + eps  # [batch_size]
+    p_mask = p_mask.unsqueeze(1).expand(-1, seq_len)  # [batch_size, seq_len]
+
+    # Don't mask padding tokens
+    if attention_mask is not None:
+        p_mask = p_mask * attention_mask.float()
+
+    # Create a random mask based on p_mask
+    random_values = torch.rand_like(p_mask)
+    masked_indices = random_values < p_mask
+
+    # Apply attention mask constraints
+    if attention_mask is not None:
+        masked_indices = masked_indices & attention_mask.bool()
+
+    # For SFT data, only mask answer tokens (where labels != -100)
+    if labels is not None:
+        answer_mask = labels != -100
+        masked_indices = masked_indices & answer_mask
+
+    # Create the noisy input by replacing masked tokens with the mask token
+    noisy_input_ids = input_ids.clone()
+    noisy_input_ids[masked_indices] = mask_token_id
+
+    return noisy_input_ids, masked_indices, p_mask
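+
+# Worked example of the schedule above (editor's note): with eps = 1e-3 and a
+# sampled timestep t = 0.5, p_mask = (1 - 1e-3) * 0.5 + 1e-3 = 0.5005, so about
+# half of the maskable tokens are replaced by mask_token_id; t -> 0 gives
+# p_mask -> eps (almost no masking) and t -> 1 gives p_mask -> 1 (near-total
+# masking).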
diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py
index c31f48b03..d22b277ea 100644
--- a/src/axolotl/integrations/diffusion/plugin.py
+++ b/src/axolotl/integrations/diffusion/plugin.py
@@ -7,7 +7,10 @@ from axolotl.integrations.base import BasePlugin
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 
-from .trainer import DiffusionTrainer
+from .args import DiffusionArgs
+from .callbacks import DiffusionGenerationCallback
+from .loss import register_diffusion_loss
+from .model_patch import patch_model_for_bidirectional_attention
 
 LOG = get_logger(__name__)
@@ -24,18 +27,69 @@ class DiffusionPlugin(BasePlugin):
         super().__init__()
         self.cfg = None
 
+        if register_diffusion_loss():
+            LOG.info("Registered ForDiffusionLM loss function")
+        else:
+            LOG.warning(
+                "Failed to register diffusion loss - older transformers version"
+            )
+
     def get_input_args(self) -> str:
         """Returns the pydantic model for LLaDA plugin arguments."""
         return "axolotl.integrations.diffusion.DiffusionArgs"
 
     def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
-        """Perform actions after model is loaded."""
+        """Configure the model for diffusion training after loading."""
         self.cfg = cfg
 
-    def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
-        """Return custom trainer class for diffusion training."""
-        return DiffusionTrainer
+        # Set the loss type for diffusion training
+        if hasattr(model, "config"):
+            model.config.loss_type = "ForDiffusionLM"
 
-    def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
+            # Store diffusion config in the model config
+            model.config.diffusion_config = {
+                "eps": getattr(cfg, "eps", 1e-3),
+                "importance_weighting": getattr(cfg, "importance_weighting", True),
+                "mask_token_id": getattr(cfg, "mask_token_id", 128002),
+            }
+
+        LOG.info("Configured model for diffusion training with ForDiffusionLM loss")
+
+        # Patch the model for bidirectional attention during training
+        patch_model_for_bidirectional_attention(model)
+        LOG.info("Applied bidirectional attention patch to model")
+
+        return model
+
+    def post_trainer_create(self, cfg: DictDefault, trainer):
         """Configure trainer after creation."""
-        trainer.set_config(cfg)
+        # Create the diffusion config from cfg
+        diffusion_config = DiffusionArgs(
+            noise_schedule=getattr(cfg, "noise_schedule", "linear"),
+            min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.1),
+            max_mask_ratio=getattr(cfg, "max_mask_ratio", 0.9),
+            num_diffusion_steps=getattr(cfg, "num_diffusion_steps", 128),
+            eps=getattr(cfg, "eps", 1e-3),
+            importance_weighting=getattr(cfg, "importance_weighting", True),
+            mask_token_id=getattr(cfg, "mask_token_id", 128002),
+            generate_samples=getattr(cfg, "generate_samples", True),
+            generation_interval=getattr(cfg, "generation_interval", 100),
+            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
+            generation_steps=getattr(cfg, "generation_steps", 128),
+            generation_temperature=getattr(cfg, "generation_temperature", 0.0),
+            generation_max_length=getattr(cfg, "generation_max_length", 100),
+        )
+
+        # Store the diffusion config on the trainer for callbacks to access
+        trainer.diffusion_config = diffusion_config
+        LOG.info("Stored diffusion config on trainer")
+
+    def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
+        """Add the diffusion generation callback if enabled."""
+        if (
+            hasattr(trainer, "diffusion_config")
+            and trainer.diffusion_config.generate_samples
+        ):
+            generation_callback = DiffusionGenerationCallback(trainer)
+            LOG.info("Added diffusion generation callback")
+            return [generation_callback]
+        return []
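+
+
+# Illustrative axolotl YAML for enabling this plugin (editor's sketch; field
+# names follow DiffusionArgs above, values are examples, not verified defaults):
+#
+#     plugins:
+#       - axolotl.integrations.diffusion.DiffusionPlugin
+#     eps: 0.001
+#     importance_weighting: true
+#     mask_token_id: 128002
+#     generate_samples: true
+#     generation_interval: 100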
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
deleted file mode 100644
index dc62035d5..000000000
--- a/src/axolotl/integrations/diffusion/trainer.py
+++ /dev/null
@@ -1,279 +0,0 @@
-"""Custom trainer for diffusion LM training."""
-
-from typing import Any, Literal
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from axolotl.core.trainers.base import AxolotlTrainer
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.logging import get_logger
-
-from .callbacks import DiffusionGenerationCallback
-
-LOG = get_logger(__name__)
-
-
-class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
-    """Custom trainer for diffusion LM training that overrides loss computation."""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.config = None
-        self._special_token_ids = None
-
-    def set_config(self, config: DictDefault):
-        """Set config for diffusion training."""
-        self.config = config
-        self._cache_special_token_ids()
-
-        if config.generate_samples:
-            generation_callback = DiffusionGenerationCallback(self)
-            self.add_callback(generation_callback)
-
-    def compute_loss(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor],
-        return_outputs: bool = False,
-        num_items_in_batch: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
-        """Override compute_loss to use diffusion loss."""
-        input_ids = inputs.get("input_ids")
-        attention_mask = inputs.get("attention_mask")
-        labels = inputs.get("labels")
-
-        if input_ids is None:
-            raise ValueError("input_ids is required for diffusion training")
-
-        loss, outputs = self._compute_diffusion_loss(
-            model, input_ids, attention_mask, labels
-        )
-
-        if return_outputs:
-            return loss, outputs
-        return loss
-
-    def _cache_special_token_ids(self):
-        """Cache special token IDs to avoid repeated tokenizer access."""
-        if self.processing_class is None:
-            self._special_token_ids = set()
-            return
-
-        tokenizer = self.processing_class
-        special_tokens = set()
-
-        if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
-            special_tokens.add(tokenizer.bos_token_id)
-        if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
-            special_tokens.add(tokenizer.eos_token_id)
-        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
-            special_tokens.add(tokenizer.pad_token_id)
-
-        self._special_token_ids = special_tokens
-
-    @torch.compile
-    def _forward_process(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-        labels: torch.Tensor | None = None,
-        eps: float = 1e-3,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Forward noising process. A timestep is sampled along the process, and tokens
-        are masked with probability determined by the configured noise schedule.
-
-        Args:
-            input_ids: Input token ids [batch_size, seq_len].
-            attention_mask: Attention mask [batch_size, seq_len].
-            labels: Labels for SFT training [batch_size, seq_len].
-            eps: Small epsilon value for minimum masking probability.
-
-        Returns:
-            noisy_batch: Input with some tokens masked.
-            masked_indices: Boolean mask indicating which tokens were masked.
-            p_mask: Masking probabilities for each token [batch_size, seq_len].
-        """
-        batch_size, seq_len = input_ids.shape
-        device = input_ids.device
-
-        # Sample random timesteps for each sample in batch
-        t = torch.rand(batch_size, device=device)
-
-        # Calculate masking probability with epsilon
-        p_mask = (1 - eps) * t + eps  # [batch_size]
-        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]
-
-        # Don't mask padding tokens if attention_mask is provided
-        if attention_mask is not None:
-            valid_mask = attention_mask.bool()
-            p_mask = p_mask * valid_mask.float()
-
-        # Create mask to exclude special tokens
-        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
-        if self._special_token_ids:
-            for token_id in self._special_token_ids:
-                special_token_mask |= input_ids == token_id
-
-        # Create random mask based on p_mask
-        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
-        masked_indices = masked_indices & ~special_token_mask
-        if attention_mask is not None:
-            masked_indices = masked_indices & attention_mask.bool()
-
-        # For SFT data, only mask answer tokens
-        if labels is not None:
-            answer_mask = labels != -100
-            masked_indices = masked_indices & answer_mask
-
-        # Create masked input
-        mask_token_id = self.config.mask_token_id
-        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
-
-        return noisy_batch, masked_indices, p_mask
-
-    @torch.compile
-    def _create_bidirectional_attention_mask(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        """
-        Create bidirectional attention mask to override default causal masking. Handles
-        sample-packed sequences where different samples are identified by different
-        attention mask values.
-
-        Args:
-            input_ids: Input token ids [batch_size, seq_len].
-            attention_mask: Attention mask [batch_size, seq_len]
-
-        Returns:
-            bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
-        """
-        batch_size, seq_len = input_ids.shape
-        device = input_ids.device
-
-        if attention_mask is None or not self.config.sample_packing:
-            return torch.ones(
-                batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
-            )
-
-        # Create attention mask by comparing sample IDs element-wise
-        mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
-        mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]
-
-        # Tokens can attend to each other if they have the same non-zero sample ID
-        bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
-
-        # Add head dimension: [batch_size, 1, seq_len, seq_len]
-        bidirectional_mask = bidirectional_mask.unsqueeze(1)
-
-        return bidirectional_mask
-
-    def _compute_diffusion_loss(
-        self,
-        model: nn.Module,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-        labels: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | Any]:
-        """
-        Compute diffusion loss.
-
-        Args:
-            model: The model to compute loss for.
-            input_ids: Ground truth token ids [batch_size, seq_len].
-            attention_mask: Attention mask [batch_size, seq_len].
-            labels: Labels for SFT training [batch_size, seq_len].
-
-        Returns:
-            loss: Cross-entropy loss.
-            metrics: Dictionary of metrics.
-        """
-        # Apply forward process
-        noisy_batch, masked_indices, p_mask = self._forward_process(
-            input_ids, attention_mask, labels, self.config.eps
-        )
-
-        # Create bidirectional attention mask
-        bidirectional_mask = self._create_bidirectional_attention_mask(
-            input_ids, attention_mask
-        )
-
-        # Forward pass
-        outputs = model(
-            input_ids=noisy_batch,
-            attention_mask=bidirectional_mask,
-        )
-        logits = outputs.logits
-
-        if masked_indices.sum() > 0:
-            valid_indices = torch.where(masked_indices)
-            batch_indices, seq_indices = valid_indices
-
-            masked_logits = logits[batch_indices, seq_indices]
-            masked_targets = input_ids[batch_indices, seq_indices]
-            masked_p_mask = p_mask[batch_indices, seq_indices]
-
-            # Compute cross-entropy loss without reduction
-            token_loss = F.cross_entropy(
-                masked_logits.float(), masked_targets, reduction="none"
-            )
-
-            if self.config.importance_weighting:
-                masked_p_mask = masked_p_mask.float()
-                weighted_loss = token_loss / masked_p_mask
-            else:
-                weighted_loss = token_loss
-
-            # Final loss: sum weighted losses, normalize
-            if labels is not None:
-                # For SFT data: normalize by answer length per sample
-                answer_mask = labels != -100
-                answer_lengths = answer_mask.sum(dim=1).float()  # [batch_size]
-
-                # Get batch indices for masked tokens
-                masked_batch_indices = batch_indices
-
-                # Sum losses per sample and divide by answer length
-                loss_per_sample = torch.zeros(
-                    input_ids.shape[0], device=input_ids.device
-                )
-                for i in range(input_ids.shape[0]):
-                    sample_mask = masked_batch_indices == i
-                    if sample_mask.sum() > 0:
-                        sample_loss = weighted_loss[sample_mask].sum()
-                        loss_per_sample[i] = sample_loss / answer_lengths[i]
-
-                loss = loss_per_sample.mean()
-            else:
-                # Original normalization for non-SFT data
-                loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
-
-            ce_loss = token_loss.mean()
-
-            # Compute accuracy on masked tokens
-            with torch.no_grad():
-                pred_tokens = masked_logits.argmax(dim=-1)
-                accuracy = (pred_tokens == masked_targets).float().mean()
-        else:
-            loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
-            accuracy = torch.tensor(0.0, device=input_ids.device)
-            ce_loss = torch.tensor(0.0, device=input_ids.device)
-            masked_p_mask = torch.tensor(1.0, device=input_ids.device)
-
-        metrics = {
-            "loss": loss.item(),
-            "accuracy": accuracy.item(),
-            "mask_ratio": masked_indices.float().mean().item(),
-            "num_masked_tokens": (masked_indices.sum().item(), "sum"),
-            "avg_p_mask": p_mask[masked_indices].mean().item(),
-            "ce_loss": ce_loss.item(),
-        }
-        if self.config.importance_weighting:
-            metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
-
-        train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
-        self.store_metrics(metrics, train_eval=train_eval)
-
-        return loss, outputs
diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py
index 583597238..37709e349 100644
--- a/tests/integrations/test_diffusion.py
+++ b/tests/integrations/test_diffusion.py
@@ -2,111 +2,180 @@
 
 # pylint: disable=redefined-outer-name,protected-access
 
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
-from axolotl.integrations.diffusion.trainer import DiffusionTrainer
-from axolotl.utils.dict import DictDefault
-
-
-@pytest.fixture
-def mock_tokenizer():
-    """Create a mock tokenizer."""
-    tokenizer = Mock()
-    tokenizer.bos_token_id = 1
-    tokenizer.eos_token_id = 2
-    tokenizer.pad_token_id = 0
-    return tokenizer
+from axolotl.integrations.diffusion.args import DiffusionArgs
+from axolotl.integrations.diffusion.loss import (
+    ForDiffusionLMLoss,
+    register_diffusion_loss,
+)
+from axolotl.integrations.diffusion.model_patch import (
+    _create_bidirectional_attention_mask,
+    _forward_process,
+    patch_model_for_bidirectional_attention,
+)
+from axolotl.integrations.diffusion.plugin import DiffusionPlugin
 
 
 @pytest.fixture
 def diffusion_config():
     """Create a diffusion config."""
-    return DictDefault(
-        {
-            "mask_token_id": 32000,
-            "eps": 1e-3,
-            "importance_weighting": False,
-            "sample_packing": False,
-        }
+    return DiffusionArgs(
+        eps=1e-3,
+        importance_weighting=False,
+        mask_token_id=32000,
+        generate_samples=False,
     )
 
 
 @pytest.fixture
-def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
-    """Create a diffusion trainer instance for testing methods directly."""
-    # Create a minimal trainer instance just for testing methods
-    trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
-    trainer.config = diffusion_config
-    trainer._special_token_ids = {0, 1, 2}  # pad, bos, eos
-    trainer.processing_class = mock_tokenizer
-    trainer.store_metrics = Mock()  # Mock metrics storage
-    return trainer
+def mock_model():
+    """Create a mock model."""
+    model = Mock()
+    model.config = Mock()
+    model.config.loss_type = "ForDiffusionLM"
+    model.config.diffusion_config = {
+        "eps": 1e-3,
+        "importance_weighting": False,
+        "mask_token_id": 32000,
+    }
+    model.training = True
+    return model
 
 
-class TestDiffusionTrainer:
-    """Test the DiffusionTrainer class."""
+class TestDiffusionLoss:
+    """Test the ForDiffusionLMLoss function."""
 
-    def test_forward_process_basic(self, diffusion_trainer_instance):
-        """Test basic forward process without labels."""
+    def test_loss_with_diffusion_info(self, mock_model):
+        """Test loss computation with stored diffusion info."""
+        batch_size, seq_len, vocab_size = 1, 5, 1000
+
+        # Mock stored diffusion info
+        original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        masked_indices = torch.tensor(
+            [[False, True, True, False, False]], dtype=torch.bool
+        )
+        p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
+
+        mock_model._diffusion_info = {
+            "original_input_ids": original_input_ids,
+            "masked_indices": masked_indices,
+            "p_mask": p_mask,
+        }
+
+        # Mock logits
+        logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
+        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
+        assert loss.item() >= 0
+
+    def test_loss_fallback_without_diffusion_info(self, mock_model):
+        """Test fallback to causal LM loss when there is no diffusion info."""
+        batch_size, seq_len, vocab_size = 1, 5, 1000
+
+        # Remove diffusion info to trigger the fallback
+        if hasattr(mock_model, "_diffusion_info"):
+            delattr(mock_model, "_diffusion_info")
+
+        logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
+        labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
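+
+    def test_importance_weighting_scales_loss(self):
+        """Editor's illustrative addition (not in the original PR): with a
+        uniform p_mask of 0.5 and no labels, importance weighting should
+        exactly double the unweighted loss."""
+        from types import SimpleNamespace
+
+        logits = torch.randn(1, 4, 50)
+        info = {
+            "original_input_ids": torch.tensor([[1, 10, 20, 2]]),
+            "masked_indices": torch.tensor([[False, True, True, False]]),
+            "p_mask": torch.full((1, 4), 0.5),
+        }
+        losses = {}
+        for weighted in (True, False):
+            model = SimpleNamespace(_diffusion_info=info)
+            config = SimpleNamespace(
+                diffusion_config={"importance_weighting": weighted}
+            )
+            losses[weighted] = ForDiffusionLMLoss(
+                logits=logits,
+                labels=None,
+                vocab_size=50,
+                config=config,
+                model=model,
+            )
+        # Dividing the token loss by p_mask = 0.5 doubles it
+        assert torch.isclose(losses[True], 2 * losses[False])
+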
+    def test_loss_no_masked_tokens(self, mock_model):
+        """Test loss when no tokens are masked."""
+        batch_size, seq_len, vocab_size = 1, 3, 1000
+
+        # No masked tokens
+        original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
+        masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
+        p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
+
+        mock_model._diffusion_info = {
+            "original_input_ids": original_input_ids,
+            "masked_indices": masked_indices,
+            "p_mask": p_mask,
+        }
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert loss.item() == 0.0
+
+
+class TestModelPatch:
+    """Test the model patching functionality."""
+
+    def test_forward_process_basic(self):
+        """Test basic forward process."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        noisy_batch, masked_indices, p_mask = (
-            diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
+        noisy_input_ids, masked_indices, p_mask = _forward_process(
+            input_ids, diffusion_config=diffusion_config
         )
 
         # Check shapes
-        assert noisy_batch.shape == input_ids.shape
+        assert noisy_input_ids.shape == input_ids.shape
         assert masked_indices.shape == input_ids.shape
         assert p_mask.shape == input_ids.shape
 
-        # Check that special tokens are not masked
-        special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
-        assert not masked_indices[special_token_positions].any()
+        # Check that the mask token is applied where masked
+        if masked_indices.any():
+            assert (noisy_input_ids[masked_indices] == 32000).all()
 
-        # Check that mask token is applied
-        mask_token_id = diffusion_trainer_instance._config.mask_token_id
-        masked_positions = masked_indices
-        if masked_positions.any():
-            assert (noisy_batch[masked_positions] == mask_token_id).all()
-
-    def test_forward_process_with_labels(self, diffusion_trainer_instance):
+    def test_forward_process_with_labels(self):
         """Test forward process with SFT labels."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
         labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        noisy_batch, masked_indices, p_mask = (
-            diffusion_trainer_instance._forward_process(
-                input_ids, labels=labels, eps=0.1
-            )
+        _, masked_indices, _ = _forward_process(
+            input_ids, labels=labels, diffusion_config=diffusion_config
         )
 
-        # Check shapes
-        assert noisy_batch.shape == input_ids.shape
-        assert masked_indices.shape == input_ids.shape
-        assert p_mask.shape == input_ids.shape
-
         # Check that only answer tokens can be masked (where labels != -100)
         non_answer_mask = labels == -100
-
-        # No masking should occur on non-answer tokens
         assert not masked_indices[non_answer_mask].any()
 
-        # p_mask should be the same for all positions (sampled timestep),
-        # but masking is only applied to answer tokens
-        assert p_mask.shape == input_ids.shape
-
-        # Verify that masked_indices respects the answer mask
-        assert not masked_indices[non_answer_mask].any()
-
-    def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
+    def test_forward_process_with_attention_mask(self):
         """Test forward process with attention mask."""
         input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
         attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
-            input_ids, attention_mask=attention_mask, eps=0.1
+        _, masked_indices, p_mask = _forward_process(
+            input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config
         )
 
         # Check that padding tokens are not masked
@@ -114,158 +183,146 @@ class TestDiffusionTrainer:
         assert not masked_indices[padding_positions].any()
         assert (p_mask[padding_positions] == 0).all()
 
-    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
-        """Test bidirectional attention mask without sample packing."""
+    def test_bidirectional_attention_mask(self):
+        """Test bidirectional attention mask creation."""
         input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
 
-        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
-            input_ids
-        )
+        mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
 
         # Should be all-to-all attention
         expected_shape = (1, 1, 4, 4)
         assert mask.shape == expected_shape
         assert mask.all()
 
-    def test_bidirectional_attention_mask_with_packing(
-        self, diffusion_trainer_instance
-    ):
-        """Test bidirectional attention mask with sample packing."""
-        diffusion_trainer_instance._config.sample_packing = True
-        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
-        # Sample IDs: first sample (1), second sample (2)
-        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
+    def test_bidirectional_attention_mask_with_padding(self):
+        """Test bidirectional attention mask with padding."""
+        input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
 
-        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
-            input_ids, attention_mask
-        )
+        mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
 
-        # Check that tokens within same sample can attend to each other
-        # but not across samples
-        assert mask[0, 0, 0, 1].item()  # First sample tokens can attend to each other
-        assert mask[0, 0, 1, 2].item()
-        assert not mask[0, 0, 0, 3].item()  # Can't attend across samples
-        assert not mask[0, 0, 2, 4].item()
-        assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other
+        # Padding positions should not attend or be attended to
+        assert not mask[0, 0, 3, :].any()  # Padding can't attend to anything
+        assert not mask[0, 0, :, 3].any()  # Nothing can attend to padding
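+
+    def test_forward_process_mask_rate_sketch(self):
+        """Editor's illustrative addition (not in the original PR): the
+        empirical mask rate should roughly track the sampled p_mask, since
+        all positions share one timestep."""
+        torch.manual_seed(0)
+        input_ids = torch.randint(3, 1000, (1, 2048))
+        diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
+
+        _, masked_indices, p_mask = _forward_process(
+            input_ids, diffusion_config=diffusion_config
+        )
+
+        # Compare the observed masking fraction against the shared p_mask value
+        empirical = masked_indices.float().mean()
+        assert abs(empirical - p_mask[0, 0]) < 0.05
+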
 
-    def test_compute_loss_basic(self, diffusion_trainer_instance):
-        """Test basic loss computation."""
-        # Mock model that returns logits
+    def test_patch_model_for_bidirectional_attention(self):
+        """Test that model patching works."""
         mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 5
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-        mock_model.return_value = mock_outputs
+        mock_model.config = Mock()
+        mock_model.config.loss_type = "ForDiffusionLM"
+        mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
         mock_model.training = True
 
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        original_forward = Mock()
+        mock_model.forward = original_forward
 
-        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
-        )
+        # Patch the model
+        patch_model_for_bidirectional_attention(mock_model)
 
-        # Check that loss is computed
-        assert isinstance(loss, torch.Tensor)
-        assert loss.requires_grad
-        assert outputs == mock_outputs
+        # Check that the forward method was replaced
+        assert mock_model.forward != original_forward
 
-        # Check that metrics were stored
-        diffusion_trainer_instance.store_metrics.assert_called_once()
 
-    def test_compute_loss_with_labels(self, diffusion_trainer_instance):
-        """Test loss computation with SFT labels."""
-        # Mock model
+class TestDiffusionPlugin:
+    """Test the DiffusionPlugin."""
+
+    def test_plugin_registers_loss_function(self):
+        """Test that the plugin registers the diffusion loss function."""
+        with patch(
+            "axolotl.integrations.diffusion.plugin.register_diffusion_loss",
+            return_value=True,
+        ) as mock_register:
+            DiffusionPlugin()
+            mock_register.assert_called_once()
+
+    def test_post_model_load_configuration(self):
+        """Test that post_model_load configures the model correctly."""
+        plugin = DiffusionPlugin()
+
+        # Mock model and config
         mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 5
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
+        mock_model.config = Mock()
+        mock_cfg = Mock()
+        mock_cfg.eps = 1e-3
+        mock_cfg.importance_weighting = True
+        mock_cfg.mask_token_id = 32000
 
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
-        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+        with patch(
+            "axolotl.integrations.diffusion.plugin."
+            "patch_model_for_bidirectional_attention"
+        ) as mock_patch:
+            result = plugin.post_model_load(mock_cfg, mock_model)
 
-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids, labels=labels
-        )
+            # Check model configuration
+            assert mock_model.config.loss_type == "ForDiffusionLM"
+            assert mock_model.config.diffusion_config is not None
+            assert mock_model.config.diffusion_config["eps"] == 1e-3
 
-        # Check that loss is computed
-        assert isinstance(loss, torch.Tensor)
-        assert loss.requires_grad
+            # Check the model was patched
+            mock_patch.assert_called_once_with(mock_model)
+
+            # Should return the model
+            assert result == mock_model
+
+    def test_post_trainer_create_stores_config(self, diffusion_config):
+        """Test that post_trainer_create stores the config on the trainer."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Set config attributes
+        for attr, value in diffusion_config.model_dump().items():
+            setattr(mock_cfg, attr, value)
+
+        plugin.post_trainer_create(mock_cfg, mock_trainer)
+
+        # Check that the diffusion config was stored on the trainer
+        assert hasattr(mock_trainer, "diffusion_config")
+        assert mock_trainer.diffusion_config.eps == diffusion_config.eps
+
+    def test_add_callbacks_post_trainer_with_generation_enabled(self):
+        """Test callback addition when generation is enabled."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Mock trainer with a diffusion config that has generation enabled
+        mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
+
+        with patch(
+            "axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
+        ) as mock_callback_class:
+            callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
+
+            # Should return one callback
+            assert len(callbacks) == 1
+            mock_callback_class.assert_called_once_with(mock_trainer)
+
+    def test_add_callbacks_post_trainer_with_generation_disabled(self):
+        """Test callback addition when generation is disabled."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Mock trainer with a diffusion config that has generation disabled
+        mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
+
+        callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
+
+        # Should return no callbacks
+        assert len(callbacks) == 0
 
-        # Check that SFT metrics were added
-        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
-        assert "answer_ratio" in call_args
-        assert "avg_answer_length" in call_args
 
-    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
-        """Test loss computation when no tokens are masked."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 3
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
+class TestLossRegistration:
+    """Test loss function registration."""
 
-        # Only special tokens (which won't be masked)
-        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
+    def test_register_diffusion_loss(self):
+        """Test that the loss function can be registered."""
+        with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
+            result = register_diffusion_loss()
+            assert result is True
+            assert "ForDiffusionLM" in mock_mapping
+            assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
 
-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
-        )
-
-        # Loss should be zero when no tokens are masked
-        assert loss.item() == 0.0
-        assert loss.requires_grad
-
-    def test_cache_special_token_ids(self, diffusion_trainer_instance):
-        """Test caching of special token IDs."""
-        # Should cache BOS, EOS, PAD tokens
-        expected_tokens = {0, 1, 2}  # pad, bos, eos
-        assert diffusion_trainer_instance._special_token_ids == expected_tokens
-
-    def test_cache_special_token_ids_no_tokenizer(self):
-        """Test caching when no tokenizer is available."""
-        trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
-        trainer.processing_class = None
-        trainer._cache_special_token_ids()
-
-        assert trainer._special_token_ids == set()
-
-    def test_main_compute_loss_interface(self, diffusion_trainer_instance):
-        """Test the main compute_loss interface."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        mock_outputs.logits = torch.randn(1, 5, 1000)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
-        inputs = {
-            "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
-            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
-            "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
-        }
-
-        # Test without return_outputs
-        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
-        assert isinstance(loss, torch.Tensor)
-
-        # Test with return_outputs
-        loss, outputs = diffusion_trainer_instance.compute_loss(
-            mock_model, inputs, return_outputs=True
-        )
-        assert isinstance(loss, torch.Tensor)
-        assert outputs == mock_outputs
-
-    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
-        """Test that missing input_ids raises ValueError."""
-        mock_model = Mock()
-        inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
-
-        with pytest.raises(ValueError, match="input_ids is required"):
-            diffusion_trainer_instance.compute_loss(mock_model, inputs)
+    def test_register_diffusion_loss_import_error(self):
+        """Test the fallback when the LOSS_MAPPING import fails."""
+        # Patch the import to raise ImportError
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("transformers.loss.loss_utils not found"),
+        ):
+            result = register_diffusion_loss()
+            assert result is False