Compare commits

..

1 Commits

Author SHA1 Message Date
Dan Saunders
1f75287a3a diffusion custom models approach 2025-08-19 04:09:46 +00:00
9 changed files with 855 additions and 603 deletions

View File

@@ -64,11 +64,25 @@ learning_rate: 3e-4
## Supported Models ## Supported Models
Any models that support 4D attention masks should work out of the box. If not, please Currently supported base model types:
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)! - **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
The plugin automatically creates custom model classes that inherit from the base model
while adding diffusion training capabilities. This provides full compatibility with
HuggingFace's ecosystem for saving, loading, and inference.
## How It Works
### Custom Model Architecture
The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
standard HuggingFace models. During training, these models:
1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
@@ -76,11 +90,10 @@ During training, tokens are randomly masked based on a sampled timestep:
- Randomly mask tokens with probability `p`
### Bidirectional Attention
The plugin uses native 4D attention masks to: The models override causal attention with bidirectional attention:
- Enable bidirectional attention without patches - Creates 4D attention masks allowing all-to-all attention
- Allow all tokens to attend to all other tokens - Maintains proper padding and sample packing masks
- Maintain proper padding masks - Compatible with standard HuggingFace attention implementations
- Work with modern `transformers` models out of the box
### Diffusion Loss
@@ -90,6 +103,22 @@ Loss is computed only on masked tokens with (optional) importance weighting:
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
### Model Loading and Saving
The custom models work seamlessly with HuggingFace's AutoModel system:
```python
from transformers import AutoModel, AutoConfig
# Load a diffusion model
model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
# Save a diffusion model
model.save_pretrained("path/to/save/diffusion/model")
```
During inference, the models behave like standard causal language models.
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
@@ -115,9 +144,19 @@ The plugin adds several metrics to track diffusion training:
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Benefits of Custom Model Approach
**Type Safety**: Full IDE support and type checking
**HuggingFace Integration**: Works with AutoModel, Hub, pipelines
**Maintainability**: Clean architecture, no monkey patching
**Ecosystem Compatibility**: Standard save/load, PEFT support
**Testing**: Easier to test and debug
## Limitations ## Limitations
- No flash attention support - **Model Support**: Currently limited to Llama and Mistral architectures
- **Flash Attention**: Not yet optimized for flash attention
- **Inference Speed**: Bidirectional attention is slower than causal for generation
## References

View File

@@ -1,6 +1,26 @@
"""Diffusion LM training plugin init.""" """Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin from .plugin import DiffusionPlugin
__all__ = ["DiffusionArgs", "DiffusionPlugin"] # Register custom configurations
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]

View File

@@ -26,29 +26,31 @@ class DiffusionGenerationCallback(TrainerCallback):
**kwargs, **kwargs,
): ):
"""Generate samples at specified intervals.""" """Generate samples at specified intervals."""
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if ( if (
state.global_step > 0 state.global_step > 0
and state.global_step % self.trainer.config.generation_interval == 0 and state.global_step % config.get('generation_interval', 100) == 0
): ):
# Use eval dataloader if available, otherwise use train dataloader # Use eval dataloader if available, otherwise use train dataloader
if ( if (
hasattr(self.trainer, "eval_dataset") hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None and self.trainer.eval_dataset is not None
): ):
dataloader = self.trainer.callback_handler.eval_dataloader dataloader = self.trainer.get_eval_dataloader()
else: else:
dataloader = self.trainer.callback_handler.train_dataloader dataloader = self.trainer.get_train_dataloader()
# Generate samples # Generate samples
samples = generate_samples( samples = generate_samples(
model=self.trainer.model, model=self.trainer.model,
tokenizer=self.trainer.tokenizer, tokenizer=self.trainer.tokenizer,
dataloader=dataloader, dataloader=dataloader,
num_generation_samples=self.trainer.config.num_generation_samples, num_generation_samples=config.get('num_generation_samples', 3),
max_length=self.trainer.config.generation_max_length, max_length=config.get('generation_max_length', 256),
num_diffusion_steps=self.trainer.config.generation_steps, num_diffusion_steps=config.get('generation_steps', 10),
temperature=self.trainer.config.generation_temperature, temperature=config.get('generation_temperature', 1.0),
mask_token_id=self.trainer.config.mask_token_id, mask_token_id=config.get('mask_token_id', 32000),
) )
# Log samples # Log samples
@@ -81,7 +83,8 @@ class DiffusionGenerationCallback(TrainerCallback):
LOG.info("=" * 60) LOG.info("=" * 60)
if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero: config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
if wandb.run is not None: if wandb.run is not None:
wandb.log( wandb.log(
{ {

View File

@@ -0,0 +1,71 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
    """Llama configuration extended with diffusion-training hyperparameters.

    All diffusion-specific settings are stored as plain attributes on the
    config object alongside the standard Llama fields.
    """

    model_type = "llama_diffusion"

    def __init__(
        self,
        mask_token_id: int = 32000,
        eps: float = 1e-3,
        importance_weighting: bool = False,
        sample_packing: bool = False,
        min_mask_ratio: float = 0.0,
        max_mask_ratio: float = 1.0,
        noise_schedule: str = "linear",
        **kwargs,
    ):
        # All base Llama settings are consumed by the parent class.
        super().__init__(**kwargs)

        # Diffusion-specific settings.
        diffusion_fields = {
            "mask_token_id": mask_token_id,
            "eps": eps,
            "importance_weighting": importance_weighting,
            "sample_packing": sample_packing,
            "min_mask_ratio": min_mask_ratio,
            "max_mask_ratio": max_mask_ratio,
            "noise_schedule": noise_schedule,
        }
        for field_name, field_value in diffusion_fields.items():
            setattr(self, field_name, field_value)
class MistralForDiffusionConfig(MistralConfig):
    """Mistral configuration extended with diffusion-training hyperparameters.

    All diffusion-specific settings are stored as plain attributes on the
    config object alongside the standard Mistral fields.
    """

    model_type = "mistral_diffusion"

    def __init__(
        self,
        mask_token_id: int = 32000,
        eps: float = 1e-3,
        importance_weighting: bool = False,
        sample_packing: bool = False,
        min_mask_ratio: float = 0.0,
        max_mask_ratio: float = 1.0,
        noise_schedule: str = "linear",
        **kwargs,
    ):
        # All base Mistral settings are consumed by the parent class.
        super().__init__(**kwargs)

        # Diffusion-specific settings.
        diffusion_fields = {
            "mask_token_id": mask_token_id,
            "eps": eps,
            "importance_weighting": importance_weighting,
            "sample_packing": sample_packing,
            "min_mask_ratio": min_mask_ratio,
            "max_mask_ratio": max_mask_ratio,
            "noise_schedule": noise_schedule,
        }
        for field_name, field_value in diffusion_fields.items():
            setattr(self, field_name, field_value)
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
    """
    Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
    """

    model_type = "diffusion"

    def __init__(self, **kwargs):
        # Surface the deprecation at construction time instead of only in the
        # docstring, so existing callers get an actionable warning.
        import warnings

        warnings.warn(
            "DiffusionConfig is deprecated; use LlamaForDiffusionConfig or "
            "MistralForDiffusionConfig instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(**kwargs)

View File

@@ -1,115 +0,0 @@
"""Diffusion LM loss function for integration with transformers LOSS_MAPPING."""
from typing import Optional
import torch
import torch.nn.functional as F
def ForDiffusionLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    config: Optional[dict] = None,
    inputs: Optional[dict] = None,
    model: Optional[torch.nn.Module] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Diffusion Language Modeling loss for transformers' LOSS_MAPPING.

    Computes cross-entropy only on tokens masked by the forward diffusion
    process, using the `_diffusion_info` dict the model patch stored during
    the forward pass. Falls back to the standard shifted causal-LM loss when
    no diffusion info is available.

    Args:
        logits: Model predictions [batch_size, seq_len, vocab_size].
        labels: Ground-truth tokens [batch_size, seq_len]; -100 marks
            positions excluded from SFT supervision.
        vocab_size: Vocabulary size (unused; kept for the LOSS_MAPPING API).
        config: Model configuration; `config.diffusion_config` may carry an
            `importance_weighting` flag (defaults to True).
        inputs: Input batch dict (unused; kept for the LOSS_MAPPING API).
        model: Model instance carrying `_diffusion_info`.
        **kwargs: Ignored extra trainer arguments.

    Returns:
        loss: Scalar loss tensor.
    """
    # Fallback: regular next-token causal LM loss if no diffusion info.
    if model is None or not hasattr(model, "_diffusion_info"):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        return loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )

    diffusion_info = model._diffusion_info
    original_input_ids = diffusion_info["original_input_ids"]
    masked_indices = diffusion_info["masked_indices"]
    p_mask = diffusion_info["p_mask"]

    # Get diffusion config parameters.
    diffusion_config = getattr(config, "diffusion_config", {})
    importance_weighting = diffusion_config.get("importance_weighting", True)

    # Nothing masked this step: return a zero loss that still participates in
    # autograd so the training loop needs no special-casing.
    if not masked_indices.any():
        return torch.tensor(0.0, device=logits.device, requires_grad=True)

    # Predictions and targets at masked positions only.
    masked_logits = logits[masked_indices]
    masked_targets = original_input_ids[masked_indices]

    token_loss = F.cross_entropy(
        masked_logits.float(), masked_targets, reduction="none"
    )

    if importance_weighting:
        # Importance weighting 1 / p(mask) corrects for the sampled timestep.
        masked_p_mask = p_mask.expand_as(masked_indices)[masked_indices]
        weighted_loss = token_loss / masked_p_mask

        if labels is not None:
            # SFT data: per-sample sum normalized by answer length.
            answer_mask = labels != -100
            # clamp matches the original max(len, 1) guard against empty answers.
            answer_lengths = answer_mask.sum(dim=1).float().clamp(min=1.0)

            batch_size = original_input_ids.shape[0]
            batch_indices = torch.arange(
                batch_size, device=original_input_ids.device
            )
            batch_indices = batch_indices.unsqueeze(1).expand_as(masked_indices)
            masked_batch_indices = batch_indices[masked_indices]

            # Vectorized per-sample accumulation (replaces an O(batch)
            # Python loop over samples).
            loss_per_sample = torch.zeros(
                batch_size,
                device=original_input_ids.device,
                dtype=weighted_loss.dtype,
            )
            loss_per_sample.index_add_(0, masked_batch_indices, weighted_loss)
            loss = (loss_per_sample / answer_lengths).mean()
        else:
            # Completion data: simple average.
            loss = weighted_loss.mean()
    else:
        # No importance weighting.
        loss = token_loss.mean()

    return loss
def register_diffusion_loss():
    """Install `ForDiffusionLMLoss` into transformers' LOSS_MAPPING.

    Returns:
        True when registration succeeded, False when the running
        transformers version does not expose `LOSS_MAPPING`.
    """
    try:
        from transformers.loss.loss_utils import LOSS_MAPPING
    except ImportError:
        # Older transformers: no pluggable loss registry available.
        return False

    LOSS_MAPPING["ForDiffusionLM"] = ForDiffusionLMLoss
    return True

View File

@@ -1,149 +0,0 @@
"""Model patches for diffusion training."""
import torch
def patch_model_for_bidirectional_attention(model):
    """
    Monkey-patch `model.forward` for diffusion training.

    When the model is training and its config sets
    `loss_type == "ForDiffusionLM"`, the wrapped forward:
    - applies the forward diffusion process (random masking) to `input_ids`,
    - replaces the 2D attention mask with a bidirectional 4D mask,
    - stores `_diffusion_info` on the model for the loss function.

    All other calls are forwarded unchanged.
    """
    original_forward = model.forward

    def diffusion_forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ):
        is_diffusion_step = (
            self.training
            and getattr(self.config, "loss_type", None) == "ForDiffusionLM"
        )
        if is_diffusion_step:
            # Keep the clean ids so the loss can supervise masked positions.
            clean_input_ids = input_ids.clone()

            # Forward diffusion process (masking).
            diffusion_config = getattr(self.config, "diffusion_config", {})
            input_ids, masked_indices, p_mask = _forward_process(
                input_ids, attention_mask, labels, diffusion_config
            )

            # Swap the causal mask for a bidirectional one.
            if attention_mask is not None:
                attention_mask = _create_bidirectional_attention_mask(
                    input_ids, attention_mask
                )

            # Side-channel consumed by ForDiffusionLMLoss.
            self._diffusion_info = {
                "original_input_ids": clean_input_ids,
                "masked_indices": masked_indices,
                "p_mask": p_mask,
            }

        return original_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    # Bind the wrapper as an instance method so `self` is the patched model.
    model.forward = diffusion_forward.__get__(model, model.__class__)
def _create_bidirectional_attention_mask(
input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Create bidirectional attention mask from 2D attention mask.
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: 2D attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
"""
batch_size, seq_len = input_ids.shape
# Simple bidirectional mask - all tokens can attend to all valid tokens
# Expand 2D mask to 4D: [batch_size, seq_len] -> [batch_size, 1, seq_len, seq_len]
bidirectional_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S]
bidirectional_mask = bidirectional_mask.expand(batch_size, 1, seq_len, seq_len)
# Apply row-wise masking (padded tokens can't attend to anything)
row_mask = attention_mask.unsqueeze(1).unsqueeze(3) # [B, 1, S, 1]
bidirectional_mask = bidirectional_mask & row_mask
return bidirectional_mask
def _forward_process(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
diffusion_config: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Apply forward diffusion process (random masking).
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
labels: Labels for SFT training [batch_size, seq_len]
diffusion_config: Diffusion configuration dict
Returns:
noisy_input_ids: Input with masked tokens
masked_indices: Boolean mask of which tokens were masked
p_mask: Masking probabilities used
"""
if diffusion_config is None:
diffusion_config = {}
batch_size, seq_len = input_ids.shape
device = input_ids.device
eps = diffusion_config.get("eps", 1e-3)
mask_token_id = diffusion_config.get("mask_token_id", 128002)
# Sample random timesteps for each sample
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask.unsqueeze(1).expand(-1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens
if attention_mask is not None:
p_mask = p_mask * attention_mask.float()
# Create random mask based on p_mask
random_values = torch.rand_like(p_mask)
masked_indices = random_values < p_mask
# Apply attention mask constraints
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens (where labels != -100)
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create noisy input by replacing masked tokens
noisy_input_ids = input_ids.clone()
noisy_input_ids[masked_indices] = mask_token_id
return noisy_input_ids, masked_indices, p_mask

View File

@@ -0,0 +1,426 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
    """
    Llama model for diffusion language modeling.

    Extends `LlamaForCausalLM` with diffusion training: during training the
    forward pass applies the forward diffusion (masking) process, swaps in a
    bidirectional attention mask, and computes the diffusion loss. During
    inference it behaves like the standard causal language model.
    """

    config_class = LlamaForDiffusionConfig

    def __init__(self, config):
        super().__init__(config)
        # Ids that must never be masked; populated via `set_tokenizer`.
        self._special_token_ids = None
        # Initialize weights and apply final processing for this subclass.
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Set tokenizer for special token handling."""
        self._cache_special_token_ids(tokenizer)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training logic.

        During training (with `input_ids` given), applies the forward
        diffusion process and bidirectional attention and returns the
        diffusion loss. Otherwise defers to the standard causal LM forward.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if not (self.training and input_ids is not None):
            # Standard causal forward for inference / embeds-only calls.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )

        # Training: keep clean ids for the loss, then mask per sampled timestep.
        original_input_ids = input_ids.clone()
        noisy_input_ids, masked_indices, p_mask = self._forward_process(
            input_ids, attention_mask, labels, self.config.eps
        )

        # All-to-all attention instead of the default causal mask.
        # NOTE(review): passes a boolean 4D mask to the base model — assumes
        # the installed transformers version accepts 4D boolean masks here.
        bidirectional_attention_mask = self._create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        outputs = super().forward(
            input_ids=noisy_input_ids,
            attention_mask=bidirectional_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,  # diffusion loss is computed below instead
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        # With return_dict=False the base model returns a tuple whose first
        # element is the logits (no loss slot, since labels=None above).
        # Fix: previously `outputs.logits` raised on the tuple path.
        logits = outputs.logits if return_dict else outputs[0]

        loss = self._compute_diffusion_loss(
            original_input_ids,
            attention_mask,
            labels,
            logits,
            masked_indices,
            p_mask,
        )

        if return_dict:
            outputs.loss = loss
            return outputs
        # Fix: keep the logits in the tuple output (previously `outputs[1:]`
        # dropped them).
        return (loss,) + outputs
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
    """
    Mistral model for diffusion language modeling.

    Extends `MistralForCausalLM` with diffusion training: during training the
    forward pass applies the forward diffusion (masking) process, swaps in a
    bidirectional attention mask, and computes the diffusion loss. During
    inference it behaves like the standard causal language model.
    """

    config_class = MistralForDiffusionConfig

    def __init__(self, config):
        super().__init__(config)
        # Ids that must never be masked; populated via `set_tokenizer`.
        self._special_token_ids = None
        # Initialize weights and apply final processing for this subclass.
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Set tokenizer for special token handling."""
        self._cache_special_token_ids(tokenizer)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training logic.

        During training (with `input_ids` given), applies the forward
        diffusion process and bidirectional attention and returns the
        diffusion loss. Otherwise defers to the standard causal LM forward.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if not (self.training and input_ids is not None):
            # Standard causal forward for inference / embeds-only calls.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )

        # Training: keep clean ids for the loss, then mask per sampled timestep.
        original_input_ids = input_ids.clone()
        noisy_input_ids, masked_indices, p_mask = self._forward_process(
            input_ids, attention_mask, labels, self.config.eps
        )

        # All-to-all attention instead of the default causal mask.
        # NOTE(review): passes a boolean 4D mask to the base model — assumes
        # the installed transformers version accepts 4D boolean masks here.
        bidirectional_attention_mask = self._create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        outputs = super().forward(
            input_ids=noisy_input_ids,
            attention_mask=bidirectional_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,  # diffusion loss is computed below instead
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        # With return_dict=False the base model returns a tuple whose first
        # element is the logits (no loss slot, since labels=None above).
        # Fix: previously `outputs.logits` raised on the tuple path.
        logits = outputs.logits if return_dict else outputs[0]

        loss = self._compute_diffusion_loss(
            original_input_ids,
            attention_mask,
            labels,
            logits,
            masked_indices,
            p_mask,
        )

        if return_dict:
            outputs.loss = loss
            return outputs
        # Fix: keep the logits in the tuple output (previously `outputs[1:]`
        # dropped them).
        return (loss,) + outputs

View File

@@ -1,16 +1,20 @@
"""Diffusion LM training plugin for Axolotl.""" """Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel from peft import PeftModel
from transformers import PreTrainedModel from transformers import AutoConfig, AutoModel, PreTrainedModel
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .args import DiffusionArgs
from .callbacks import DiffusionGenerationCallback from .callbacks import DiffusionGenerationCallback
from .loss import register_diffusion_loss from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .model_patch import patch_model_for_bidirectional_attention from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -27,70 +31,68 @@ class DiffusionPlugin(BasePlugin):
super().__init__() super().__init__()
self.cfg = None self.cfg = None
if register_diffusion_loss():
LOG.info("Registered ForDiffusionLM loss function")
else:
LOG.warning(
"Failed to register diffusion loss - older transformers version"
)
def get_input_args(self) -> str: def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model for diffusion training after loading.""" """Configure model after loading."""
self.cfg = cfg self.cfg = cfg
# Set loss type for diffusion training # Set tokenizer on diffusion models for special token handling
if hasattr(model, "config"): if hasattr(model, "set_tokenizer"):
model.config.loss_type = "ForDiffusionLM" # Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
# Store diffusion config in model config def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
model.config.diffusion_config = { """Add diffusion-specific callbacks after trainer creation."""
"eps": getattr(cfg, "eps", 1e-3), callbacks = []
"importance_weighting": getattr(cfg, "importance_weighting", True),
"mask_token_id": getattr(cfg, "mask_token_id", 128002),
}
LOG.info("Configured model for diffusion training with ForDiffusionLM loss") # Store diffusion config on trainer for callbacks
trainer.diffusion_config = cfg
# Patch model for bidirectional attention during training # Add generation callback if enabled
patch_model_for_bidirectional_attention(model) if cfg.get("generate_samples", False):
LOG.info("Applied bidirectional attention patch to model")
return model
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Configure trainer after creation."""
# Create diffusion config from cfg
diffusion_config = DiffusionArgs(
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.1),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 0.9),
num_diffusion_steps=getattr(cfg, "num_diffusion_steps", 128),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", True),
mask_token_id=getattr(cfg, "mask_token_id", 128002),
generate_samples=getattr(cfg, "generate_samples", True),
generation_interval=getattr(cfg, "generation_interval", 100),
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
generation_steps=getattr(cfg, "generation_steps", 128),
generation_temperature=getattr(cfg, "generation_temperature", 0.0),
generation_max_length=getattr(cfg, "generation_max_length", 100),
)
# Store diffusion config on trainer for callbacks to access
trainer.diffusion_config = diffusion_config
LOG.info("Stored diffusion config on trainer")
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
"""Add diffusion generation callback if enabled."""
if (
hasattr(trainer, "diffusion_config")
and trainer.diffusion_config.generate_samples
):
generation_callback = DiffusionGenerationCallback(trainer) generation_callback = DiffusionGenerationCallback(trainer)
LOG.info("Added diffusion generation callback") callbacks.append(generation_callback)
return [generation_callback]
return [] return callbacks

View File

@@ -1,4 +1,4 @@
"""Tests for diffusion trainer integration.""" """Tests for diffusion model integration."""
# pylint: disable=redefined-outer-name,protected-access # pylint: disable=redefined-outer-name,protected-access
@@ -7,175 +7,114 @@ from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
from axolotl.integrations.diffusion.args import DiffusionArgs from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
from axolotl.integrations.diffusion.loss import ( from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
ForDiffusionLMLoss, from axolotl.utils.dict import DictDefault
register_diffusion_loss,
)
from axolotl.integrations.diffusion.model_patch import ( @pytest.fixture
_create_bidirectional_attention_mask, def mock_tokenizer():
_forward_process, """Create a mock tokenizer."""
patch_model_for_bidirectional_attention, tokenizer = Mock()
) tokenizer.bos_token_id = 1
from axolotl.integrations.diffusion.plugin import DiffusionPlugin tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 0
return tokenizer
@pytest.fixture @pytest.fixture
def diffusion_config(): def diffusion_config():
"""Create a diffusion config.""" """Create a diffusion config."""
return DiffusionArgs( return LlamaForDiffusionConfig(
mask_token_id=32000,
eps=1e-3, eps=1e-3,
importance_weighting=False, importance_weighting=False,
mask_token_id=32000, sample_packing=False,
generate_samples=False, # Basic llama config fields - smaller for testing
vocab_size=1000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
) )
@pytest.fixture @pytest.fixture
def mock_model(): def diffusion_model_instance(mock_tokenizer, diffusion_config):
"""Create a mock model.""" """Create a diffusion model instance for testing methods directly."""
model = Mock() # Create a minimal model instance for testing
model.config = Mock() model = object.__new__(LlamaForDiffusionLM)
model.config.loss_type = "ForDiffusionLM" model.config = diffusion_config
model.config.diffusion_config = { model._special_token_ids = {0, 1, 2} # pad, bos, eos
"eps": 1e-3,
"importance_weighting": False,
"mask_token_id": 32000,
}
model.training = True model.training = True
# Set tokenizer
model.set_tokenizer(mock_tokenizer)
return model return model
class TestDiffusionLoss: class TestDiffusionModel:
"""Test the ForDiffusionLMLoss function.""" """Test the DiffusionModel class."""
def test_loss_with_diffusion_info(self, mock_model): def test_forward_process_basic(self, diffusion_model_instance):
"""Test loss computation with stored diffusion info.""" """Test basic forward process without labels."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Mock stored diffusion info
original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
masked_indices = torch.tensor(
[[False, True, True, False, False]], dtype=torch.bool
)
p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
# Mock logits
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert loss.item() >= 0
def test_loss_fallback_without_diffusion_info(self, mock_model):
"""Test fallback to causal LM loss when no diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Remove diffusion info to trigger fallback
if hasattr(mock_model, "_diffusion_info"):
delattr(mock_model, "_diffusion_info")
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_loss_no_masked_tokens(self, mock_model):
"""Test loss when no tokens are masked."""
batch_size, seq_len, vocab_size = 1, 3, 1000
# No masked tokens
original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert loss.item() == 0.0
class TestModelPatch:
"""Test the model patching functionality."""
def test_forward_process_basic(self):
"""Test basic forward process."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_input_ids, masked_indices, p_mask = _forward_process( noisy_batch, masked_indices, p_mask = (
input_ids, diffusion_config=diffusion_config diffusion_model_instance._forward_process(input_ids, eps=0.1)
) )
# Check shapes # Check shapes
assert noisy_input_ids.shape == input_ids.shape assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape assert p_mask.shape == input_ids.shape
# Check that mask token is applied where masked # Check that special tokens are not masked
if masked_indices.any(): special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert (noisy_input_ids[masked_indices] == 32000).all() assert not masked_indices[special_token_positions].any()
def test_forward_process_with_labels(self): # Check that mask token is applied
mask_token_id = diffusion_model_instance.config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_model_instance):
"""Test forward process with SFT labels.""" """Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
_, masked_indices, _ = _forward_process( noisy_batch, masked_indices, p_mask = (
input_ids, labels=labels, diffusion_config=diffusion_config diffusion_model_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
) )
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100) # Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100 non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any() assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self): # p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_model_instance):
"""Test forward process with attention mask.""" """Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
_, masked_indices, p_mask = _forward_process( _, masked_indices, p_mask = diffusion_model_instance._forward_process(
input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config input_ids, attention_mask=attention_mask, eps=0.1
) )
# Check that padding tokens are not masked # Check that padding tokens are not masked
@@ -183,153 +122,169 @@ class TestModelPatch:
assert not masked_indices[padding_positions].any() assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all() assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask(self): def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
"""Test bidirectional attention mask creation.""" """Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask) mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids
)
# Should be all-to-all attention # Should be all-to-all attention
expected_shape = (1, 1, 4, 4) expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape assert mask.shape == expected_shape
assert mask.all() assert mask.all()
def test_bidirectional_attention_mask_with_padding(self): def test_bidirectional_attention_mask_with_packing(
"""Test bidirectional attention mask with padding.""" self, diffusion_model_instance
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) ):
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) """Test bidirectional attention mask with sample packing."""
diffusion_model_instance.config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask) mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Padding positions should not attend or be attended to # Check that tokens within same sample can attend to each other
assert not mask[0, 0, 3, :].any() # Padding can't attend to anything # but not across samples
assert not mask[0, 0, :, 3].any() # Nothing can attend to padding assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_patch_model_for_bidirectional_attention(self): def test_compute_loss_basic(self, diffusion_model_instance):
"""Test that model patching works.""" """Test basic loss computation."""
mock_model = Mock() input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
mock_model.config = Mock()
mock_model.config.loss_type = "ForDiffusionLM"
mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
mock_model.training = True
original_forward = Mock() # Create mock data for loss computation
mock_model.forward = original_forward vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Patch the model # Create a simple masked indices tensor (mask middle tokens)
patch_model_for_bidirectional_attention(mock_model) masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
# Check that forward method was replaced loss = diffusion_model_instance._compute_diffusion_loss(
assert mock_model.forward != original_forward input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
class TestDiffusionPlugin: def test_compute_loss_with_labels(self, diffusion_model_instance):
"""Test the DiffusionPlugin.""" """Test loss computation with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
def test_plugin_registers_loss_function(self): # Create mock data for loss computation
"""Test that plugin registers diffusion loss function.""" vocab_size = 1000
with patch( seq_len = 5
"axolotl.integrations.diffusion.plugin.register_diffusion_loss", logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
return_value=True,
) as mock_register:
plugin = DiffusionPlugin()
mock_register.assert_called_once()
def test_post_model_load_configuration(self): # Create masked indices that only covers answer tokens
"""Test that post_model_load configures model correctly.""" masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
plugin = DiffusionPlugin() p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
# Mock model and config loss = diffusion_model_instance._compute_diffusion_loss(
mock_model = Mock() input_ids=input_ids,
mock_model.config = Mock() labels=labels,
mock_cfg = Mock() logits=logits,
mock_cfg.eps = 1e-3 masked_indices=masked_indices,
mock_cfg.importance_weighting = True p_mask=p_mask,
mock_cfg.mask_token_id = 32000 )
with patch( # Check that loss is computed
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention" assert isinstance(loss, torch.Tensor)
) as mock_patch: assert loss.requires_grad
result = plugin.post_model_load(mock_cfg, mock_model)
# Check model configuration def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
assert mock_model.config.loss_type == "ForDiffusionLM" """Test loss computation when no tokens are masked."""
assert mock_model.config.diffusion_config is not None input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
assert mock_model.config.diffusion_config["eps"] == 1e-3
# Check model was patched # Create mock data for loss computation
mock_patch.assert_called_once_with(mock_model) vocab_size = 1000
seq_len = 3
logits = torch.randn(1, seq_len, vocab_size)
# Should return the model # No tokens masked
assert result == mock_model masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
def test_post_trainer_create_stores_config(self, diffusion_config): loss = diffusion_model_instance._compute_diffusion_loss(
"""Test that post_trainer_create stores config on trainer.""" input_ids=input_ids,
plugin = DiffusionPlugin() logits=logits,
mock_trainer = Mock() masked_indices=masked_indices,
mock_cfg = Mock() p_mask=p_mask,
)
# Set config attributes # Loss should be zero when no tokens are masked
for attr, value in diffusion_config.model_dump().items(): assert loss.item() == 0.0
setattr(mock_cfg, attr, value) assert loss.requires_grad
plugin.post_trainer_create(mock_cfg, mock_trainer) def test_cache_special_token_ids(self, diffusion_model_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_model_instance._special_token_ids == expected_tokens
# Check that diffusion config was stored on trainer def test_cache_special_token_ids_no_tokenizer(self):
assert hasattr(mock_trainer, "diffusion_config") """Test caching when no tokenizer is available."""
assert mock_trainer.diffusion_config.eps == diffusion_config.eps # Mock the parent model initialization to avoid loading pretrained weights
with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
model._cache_special_token_ids(None)
assert model._special_token_ids == set()
def test_add_callbacks_post_trainer_with_generation_enabled(self): def test_forward_training_mode(self, diffusion_model_instance):
"""Test callback addition when generation is enabled.""" """Test forward pass in training mode."""
plugin = DiffusionPlugin() input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
mock_trainer = Mock() attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation enabled # Mock the parent forward method
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True) with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_output.logits = torch.randn(1, 5, 32000)
mock_forward.return_value = mock_output
with patch( # Set training mode
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback" diffusion_model_instance.training = True
) as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return one callback result = diffusion_model_instance.forward(
assert len(callbacks) == 1 input_ids=input_ids,
mock_callback_class.assert_called_once_with(mock_trainer) attention_mask=attention_mask,
return_dict=True
)
def test_add_callbacks_post_trainer_with_generation_disabled(self): # Should call parent forward and compute loss
"""Test callback addition when generation is disabled.""" assert mock_forward.called
plugin = DiffusionPlugin() assert hasattr(result, 'loss')
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation disabled def test_forward_inference_mode(self, diffusion_model_instance):
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False) """Test forward pass in inference mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer) # Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_forward.return_value = mock_output
# Should return no callbacks # Set inference mode
assert len(callbacks) == 0 diffusion_model_instance.training = False
result = diffusion_model_instance.forward(
input_ids=input_ids,
return_dict=True
)
class TestLossRegistration: # Should just call parent forward without diffusion processing
"""Test loss function registration.""" assert mock_forward.called
assert result == mock_output
def test_register_diffusion_loss(self):
"""Test that loss function can be registered."""
with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
result = register_diffusion_loss()
assert result is True
assert "ForDiffusionLM" in mock_mapping
assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
def test_register_diffusion_loss_import_error(self):
"""Test fallback when LOSS_MAPPING import fails."""
# Patch the import to raise ImportError
with patch(
"builtins.__import__",
side_effect=ImportError("transformers.loss.loss_utils not found"),
):
result = register_diffusion_loss()
assert result is False