Compare commits


2 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Dan Saunders | 64f349b7bb | diffusion alt: custom loss impl | 2025-08-18 20:50:34 +00:00 |
| Dan Saunders | 260ebe4c93 | diffusion alt: custom loss impl | 2025-08-18 20:50:20 +00:00 |
9 changed files with 603 additions and 855 deletions

View File

@@ -64,25 +64,11 @@ learning_rate: 3e-4
## Supported Models
Currently supported base model types:
- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
The plugin automatically creates custom model classes that inherit from the base model
while adding diffusion training capabilities. This provides full compatibility with
HuggingFace's ecosystem for saving, loading, and inference.
Any models that support 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
## How It Works
### Custom Model Architecture
The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
standard HuggingFace models. During training, these models:
1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
@@ -90,10 +76,11 @@ During training, tokens are randomly masked based on a sampled timestep:
- Randomly mask tokens with probability `p`
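For illustration, a minimal sketch of this forward process (illustrative only; the actual implementation is the `_forward_process` helper shown later in this diff):

```python
import torch

def sketch_forward_mask(input_ids: torch.Tensor, mask_token_id: int = 32000, eps: float = 1e-3):
    """Sample a timestep per sequence and mask tokens with the derived probability."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)  # t ~ U[0, 1]
    p_mask = ((1 - eps) * t + eps).unsqueeze(1).expand(-1, seq_len)  # p per token
    masked_indices = torch.rand_like(p_mask) < p_mask  # mask with probability p
    noisy_input_ids = input_ids.masked_fill(masked_indices, mask_token_id)
    return noisy_input_ids, masked_indices, p_mask
```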
### Bidirectional Attention
The models override causal attention with bidirectional attention:
- Creates 4D attention masks allowing all-to-all attention
- Maintains proper padding and sample packing masks
- Compatible with standard HuggingFace attention implementations
The plugin uses native 4D attention masks to:
- Enable bidirectional attention without patches
- Allow all tokens to attend to all other tokens
- Maintain proper padding masks
- Work with modern `transformers` models out of the box
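A minimal sketch of the mask construction (mirroring the `_create_bidirectional_attention_mask` helper later in this diff; assumes a standard 0/1 padding mask):

```python
import torch

def sketch_bidirectional_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Expand a 2D padding mask [batch, seq] into a 4D all-to-all mask [batch, 1, seq, seq]."""
    col = attention_mask.bool().unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]: may be attended to
    row = attention_mask.bool().unsqueeze(1).unsqueeze(3)  # [B, 1, S, 1]: may attend
    return col & row  # valid tokens attend to all valid tokens; padding is excluded
```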
### Diffusion Loss
@@ -103,22 +90,6 @@ Loss is computed only on masked tokens with (optional) importance weighting:
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
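A minimal sketch of this objective (the full version, with per-sample normalization for SFT data, appears in the loss function later in this diff):

```python
import torch.nn.functional as F

def sketch_diffusion_loss(logits, targets, masked_indices, p_mask):
    """Importance-weighted cross-entropy over masked positions only."""
    token_loss = F.cross_entropy(
        logits[masked_indices].float(), targets[masked_indices], reduction="none"
    )
    weighted = token_loss / p_mask[masked_indices]  # rarely masked tokens weigh more
    return weighted.sum() / targets.numel()  # normalize by total token count
```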
### Model Loading and Saving
The custom models work seamlessly with HuggingFace's AutoModel system:
```python
from transformers import AutoModel, AutoConfig
# Load a diffusion model
model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
# Save a diffusion model
model.save_pretrained("path/to/save/diffusion/model")
```
During inference, the models behave like standard causal language models.
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
@@ -144,19 +115,9 @@ The plugin adds several metrics to track diffusion training:
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Benefits of Custom Model Approach
- **Type Safety**: Full IDE support and type checking
- **HuggingFace Integration**: Works with AutoModel, Hub, pipelines
- **Maintainability**: Clean architecture, no monkey patching
- **Ecosystem Compatibility**: Standard save/load, PEFT support
- **Testing**: Easier to test and debug
## Limitations
- **Model Support**: Currently limited to Llama and Mistral architectures
- **Flash Attention**: Not yet optimized for flash attention
- **Inference Speed**: Bidirectional attention is slower than causal for generation
- No flash attention support
## References

View File

@@ -1,26 +1,6 @@
"""Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin
# Register custom configurations
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]
__all__ = ["DiffusionArgs", "DiffusionPlugin"]

View File

@@ -26,31 +26,29 @@ class DiffusionGenerationCallback(TrainerCallback):
**kwargs,
):
"""Generate samples at specified intervals."""
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if (
state.global_step > 0
and state.global_step % config.get('generation_interval', 100) == 0
and state.global_step % self.trainer.diffusion_config.generation_interval == 0
):
# Use eval dataloader if available, otherwise use train dataloader
if (
hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
dataloader = self.trainer.get_eval_dataloader()
dataloader = self.trainer.callback_handler.eval_dataloader
else:
dataloader = self.trainer.get_train_dataloader()
dataloader = self.trainer.callback_handler.train_dataloader
# Generate samples
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.tokenizer,
dataloader=dataloader,
num_generation_samples=config.get('num_generation_samples', 3),
max_length=config.get('generation_max_length', 256),
num_diffusion_steps=config.get('generation_steps', 10),
temperature=config.get('generation_temperature', 1.0),
mask_token_id=config.get('mask_token_id', 32000),
num_generation_samples=self.trainer.diffusion_config.num_generation_samples,
max_length=self.trainer.diffusion_config.generation_max_length,
num_diffusion_steps=self.trainer.diffusion_config.generation_steps,
temperature=self.trainer.diffusion_config.generation_temperature,
mask_token_id=self.trainer.diffusion_config.mask_token_id,
)
# Log samples
@@ -83,8 +81,7 @@ class DiffusionGenerationCallback(TrainerCallback):
LOG.info("=" * 60)
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
if getattr(self.trainer.diffusion_config, "use_wandb", False) and self.trainer.state.is_world_process_zero:
if wandb.run is not None:
wandb.log(
{

View File

@@ -1,71 +0,0 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
"""Configuration class for Llama models with diffusion training."""
model_type = "llama_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
class MistralForDiffusionConfig(MistralConfig):
"""Configuration class for Mistral models with diffusion training."""
model_type = "mistral_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
"""
Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
"""
model_type = "diffusion"
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -0,0 +1,115 @@
"""Diffusion LM loss function for integration with transformers LOSS_MAPPING."""
from typing import Optional
import torch
import torch.nn.functional as F
def ForDiffusionLMLoss(
logits: torch.Tensor,
labels: torch.Tensor,
vocab_size: int,
config: Optional[dict] = None,
inputs: Optional[dict] = None,
model: Optional[torch.nn.Module] = None,
**kwargs,
) -> torch.Tensor:
"""
Diffusion Language Modeling loss function.
This function computes cross-entropy loss only on masked tokens using
diffusion info stored by the model patch during forward pass.
Args:
logits: Model predictions [batch_size, seq_len, vocab_size]
labels: Ground truth tokens [batch_size, seq_len]
vocab_size: Size of vocabulary
config: Model configuration (contains diffusion parameters)
inputs: Input batch dictionary (contains input_ids, attention_mask)
model: The model instance (to access stored diffusion info)
**kwargs: Additional arguments
Returns:
loss: Computed diffusion loss
"""
# Get diffusion info stored by model patch
if model is None or not hasattr(model, "_diffusion_info"):
# Fallback to regular causal LM loss if no diffusion info
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
return loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
diffusion_info = model._diffusion_info
original_input_ids = diffusion_info["original_input_ids"]
masked_indices = diffusion_info["masked_indices"]
p_mask = diffusion_info["p_mask"]
# Get diffusion config parameters
diffusion_config = getattr(config, "diffusion_config", {})
importance_weighting = diffusion_config.get("importance_weighting", True)
# Check if we have any masked tokens
if not masked_indices.any():
return torch.tensor(0.0, device=logits.device, requires_grad=True)
# Get predictions and targets for masked positions only
masked_logits = logits[masked_indices]
masked_targets = original_input_ids[masked_indices] # Original unmasked tokens
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if importance_weighting:
# Apply importance weighting: 1 / p_mask
masked_p_mask = p_mask.expand_as(masked_indices)[masked_indices]
weighted_loss = token_loss / masked_p_mask
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float()
# Group losses by batch sample
batch_indices = torch.arange(
original_input_ids.shape[0], device=original_input_ids.device
)
batch_indices = batch_indices.unsqueeze(1).expand_as(masked_indices)
masked_batch_indices = batch_indices[masked_indices]
# Sum losses per sample and normalize by answer length
loss_per_sample = torch.zeros(
original_input_ids.shape[0], device=original_input_ids.device
)
for i in range(original_input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.any():
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / max(answer_lengths[i], 1)
loss = loss_per_sample.mean()
else:
# For completion data: simple average
loss = weighted_loss.mean()
else:
# No importance weighting
loss = token_loss.mean()
return loss
def register_diffusion_loss():
"""Register the diffusion loss function in transformers LOSS_MAPPING."""
try:
from transformers.loss.loss_utils import LOSS_MAPPING
LOSS_MAPPING["ForDiffusionLM"] = ForDiffusionLMLoss
return True
except ImportError:
# Fallback for older transformers versions
return False
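# Usage sketch (illustrative, not part of this module): once registered, the
# `transformers` loss dispatch picks this function up for any model whose
# config sets `loss_type = "ForDiffusionLM"`:
#
#     register_diffusion_loss()
#     model.config.loss_type = "ForDiffusionLM"  # `model` assumed already loaded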

View File

@@ -0,0 +1,149 @@
"""Model patches for diffusion training."""
import torch
def patch_model_for_bidirectional_attention(model):
"""
Patch model to handle diffusion training with forward process and bidirectional
attention.
This monkey-patches the model's forward method to:
- Apply forward diffusion process (masking) during training
- Use bidirectional attention masks
- Store info for loss computation
"""
original_forward = model.forward
def diffusion_forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
**kwargs,
):
# Check if this is diffusion training
if (
hasattr(self.config, "loss_type")
and self.config.loss_type == "ForDiffusionLM"
and self.training
):
# Store original input_ids for loss computation
original_input_ids = input_ids.clone()
# Apply forward diffusion process (masking)
diffusion_config = getattr(self.config, "diffusion_config", {})
noisy_input_ids, masked_indices, p_mask = _forward_process(
input_ids, attention_mask, labels, diffusion_config
)
# Use noisy input for model forward
input_ids = noisy_input_ids
# Convert attention mask to bidirectional
if attention_mask is not None:
attention_mask = _create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Store diffusion info in the model for loss computation
self._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
return original_forward(
input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
)
# Replace the forward method
model.forward = diffusion_forward.__get__(model, model.__class__)
def _create_bidirectional_attention_mask(
input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Create bidirectional attention mask from 2D attention mask.
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: 2D attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
"""
batch_size, seq_len = input_ids.shape
# Simple bidirectional mask - all tokens can attend to all valid tokens
# Expand 2D mask to 4D: [batch_size, seq_len] -> [batch_size, 1, seq_len, seq_len]
bidirectional_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S]
bidirectional_mask = bidirectional_mask.expand(batch_size, 1, seq_len, seq_len)
# Apply row-wise masking (padded tokens can't attend to anything)
row_mask = attention_mask.unsqueeze(1).unsqueeze(3) # [B, 1, S, 1]
bidirectional_mask = bidirectional_mask & row_mask
return bidirectional_mask
def _forward_process(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
diffusion_config: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Apply forward diffusion process (random masking).
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
labels: Labels for SFT training [batch_size, seq_len]
diffusion_config: Diffusion configuration dict
Returns:
noisy_input_ids: Input with masked tokens
masked_indices: Boolean mask of which tokens were masked
p_mask: Masking probabilities used
"""
if diffusion_config is None:
diffusion_config = {}
batch_size, seq_len = input_ids.shape
device = input_ids.device
eps = diffusion_config.get("eps", 1e-3)
mask_token_id = diffusion_config.get("mask_token_id", 128002)
# Sample random timesteps for each sample
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask.unsqueeze(1).expand(-1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens
if attention_mask is not None:
p_mask = p_mask * attention_mask.float()
# Create random mask based on p_mask
random_values = torch.rand_like(p_mask)
masked_indices = random_values < p_mask
# Apply attention mask constraints
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens (where labels != -100)
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create noisy input by replacing masked tokens
noisy_input_ids = input_ids.clone()
noisy_input_ids[masked_indices] = mask_token_id
return noisy_input_ids, masked_indices, p_mask

View File

@@ -1,426 +0,0 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
"""
Llama model for diffusion language modeling.
This model extends LlamaForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = LlamaForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
"""
Mistral model for diffusion language modeling.
This model extends MistralForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = MistralForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

View File

@@ -1,20 +1,16 @@
"""Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers import PreTrainedModel
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .args import DiffusionArgs
from .callbacks import DiffusionGenerationCallback
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
from .loss import register_diffusion_loss
from .model_patch import patch_model_for_bidirectional_attention
LOG = get_logger(__name__)
@@ -31,68 +27,70 @@ class DiffusionPlugin(BasePlugin):
super().__init__()
self.cfg = None
if register_diffusion_loss():
LOG.info("Registered ForDiffusionLM loss function")
else:
LOG.warning(
"Failed to register diffusion loss - older transformers version"
)
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model after loading."""
"""Configure model for diffusion training after loading."""
self.cfg = cfg
# Set tokenizer on diffusion models for special token handling
if hasattr(model, "set_tokenizer"):
# Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
"""Add diffusion-specific callbacks after trainer creation."""
callbacks = []
# Store diffusion config on trainer for callbacks
trainer.diffusion_config = cfg
# Add generation callback if enabled
if cfg.get("generate_samples", False):
# Set loss type for diffusion training
if hasattr(model, "config"):
model.config.loss_type = "ForDiffusionLM"
# Store diffusion config in model config
model.config.diffusion_config = {
"eps": getattr(cfg, "eps", 1e-3),
"importance_weighting": getattr(cfg, "importance_weighting", True),
"mask_token_id": getattr(cfg, "mask_token_id", 128002),
}
LOG.info("Configured model for diffusion training with ForDiffusionLM loss")
# Patch model for bidirectional attention during training
patch_model_for_bidirectional_attention(model)
LOG.info("Applied bidirectional attention patch to model")
return model
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Configure trainer after creation."""
# Create diffusion config from cfg
diffusion_config = DiffusionArgs(
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.1),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 0.9),
num_diffusion_steps=getattr(cfg, "num_diffusion_steps", 128),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", True),
mask_token_id=getattr(cfg, "mask_token_id", 128002),
generate_samples=getattr(cfg, "generate_samples", True),
generation_interval=getattr(cfg, "generation_interval", 100),
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
generation_steps=getattr(cfg, "generation_steps", 128),
generation_temperature=getattr(cfg, "generation_temperature", 0.0),
generation_max_length=getattr(cfg, "generation_max_length", 100),
)
# Store diffusion config on trainer for callbacks to access
trainer.diffusion_config = diffusion_config
LOG.info("Stored diffusion config on trainer")
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
"""Add diffusion generation callback if enabled."""
if (
hasattr(trainer, "diffusion_config")
and trainer.diffusion_config.generate_samples
):
generation_callback = DiffusionGenerationCallback(trainer)
callbacks.append(generation_callback)
return callbacks
LOG.info("Added diffusion generation callback")
return [generation_callback]
return []

View File

@@ -1,4 +1,4 @@
"""Tests for diffusion model integration."""
"""Tests for diffusion trainer integration."""
# pylint: disable=redefined-outer-name,protected-access
@@ -7,114 +7,175 @@ from unittest.mock import Mock, patch
import pytest
import torch
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
from axolotl.utils.dict import DictDefault
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 0
return tokenizer
from axolotl.integrations.diffusion.args import DiffusionArgs
from axolotl.integrations.diffusion.loss import (
ForDiffusionLMLoss,
register_diffusion_loss,
)
from axolotl.integrations.diffusion.model_patch import (
_create_bidirectional_attention_mask,
_forward_process,
patch_model_for_bidirectional_attention,
)
from axolotl.integrations.diffusion.plugin import DiffusionPlugin
@pytest.fixture
def diffusion_config():
"""Create a diffusion config."""
return LlamaForDiffusionConfig(
mask_token_id=32000,
return DiffusionArgs(
eps=1e-3,
importance_weighting=False,
sample_packing=False,
# Basic llama config fields - smaller for testing
vocab_size=1000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
mask_token_id=32000,
generate_samples=False,
)
@pytest.fixture
def diffusion_model_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion model instance for testing methods directly."""
# Create a minimal model instance for testing
model = object.__new__(LlamaForDiffusionLM)
model.config = diffusion_config
model._special_token_ids = {0, 1, 2} # pad, bos, eos
def mock_model():
"""Create a mock model."""
model = Mock()
model.config = Mock()
model.config.loss_type = "ForDiffusionLM"
model.config.diffusion_config = {
"eps": 1e-3,
"importance_weighting": False,
"mask_token_id": 32000,
}
model.training = True
# Set tokenizer
model.set_tokenizer(mock_tokenizer)
return model
class TestDiffusionModel:
"""Test the DiffusionModel class."""
class TestDiffusionLoss:
"""Test the ForDiffusionLMLoss function."""
def test_forward_process_basic(self, diffusion_model_instance):
"""Test basic forward process without labels."""
def test_loss_with_diffusion_info(self, mock_model):
"""Test loss computation with stored diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Mock stored diffusion info
original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
masked_indices = torch.tensor(
[[False, True, True, False, False]], dtype=torch.bool
)
p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
# Mock logits
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert loss.item() >= 0
def test_loss_fallback_without_diffusion_info(self, mock_model):
"""Test fallback to causal LM loss when no diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Remove diffusion info to trigger fallback
if hasattr(mock_model, "_diffusion_info"):
delattr(mock_model, "_diffusion_info")
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_loss_no_masked_tokens(self, mock_model):
"""Test loss when no tokens are masked."""
batch_size, seq_len, vocab_size = 1, 3, 1000
# No masked tokens
original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert loss.item() == 0.0
class TestModelPatch:
"""Test the model patching functionality."""
def test_forward_process_basic(self):
"""Test basic forward process."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = (
diffusion_model_instance._forward_process(input_ids, eps=0.1)
noisy_input_ids, masked_indices, p_mask = _forward_process(
input_ids, diffusion_config=diffusion_config
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert noisy_input_ids.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied where masked
if masked_indices.any():
assert (noisy_input_ids[masked_indices] == 32000).all()
# Check that mask token is applied
mask_token_id = diffusion_model_instance.config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_model_instance):
def test_forward_process_with_labels(self):
"""Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = (
diffusion_model_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
_, masked_indices, _ = _forward_process(
input_ids, labels=labels, diffusion_config=diffusion_config
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any()
# p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_model_instance):
def test_forward_process_with_attention_mask(self):
"""Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
_, masked_indices, p_mask = diffusion_model_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1
_, masked_indices, p_mask = _forward_process(
input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config
)
# Check that padding tokens are not masked
@@ -122,169 +183,153 @@ class TestDiffusionModel:
assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
"""Test bidirectional attention mask without sample packing."""
def test_bidirectional_attention_mask(self):
"""Test bidirectional attention mask creation."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids
)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
# Should be all-to-all attention
expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape
assert mask.all()
def test_bidirectional_attention_mask_with_packing(
self, diffusion_model_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_model_instance.config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
def test_bidirectional_attention_mask_with_padding(self):
"""Test bidirectional attention mask with padding."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
# Check that tokens within same sample can attend to each other
# but not across samples
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
# Padding positions should not attend or be attended to
assert not mask[0, 0, 3, :].any() # Padding can't attend to anything
assert not mask[0, 0, :, 3].any() # Nothing can attend to padding
def test_compute_loss_basic(self, diffusion_model_instance):
"""Test basic loss computation."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create a simple masked indices tensor (mask middle tokens)
masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
def test_patch_model_for_bidirectional_attention(self):
"""Test that model patching works."""
mock_model = Mock()
mock_model.config = Mock()
mock_model.config.loss_type = "ForDiffusionLM"
mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
mock_model.training = True
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
original_forward = Mock()
mock_model.forward = original_forward
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
# Patch the model
patch_model_for_bidirectional_attention(mock_model)
def test_compute_loss_with_labels(self, diffusion_model_instance):
"""Test loss computation with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create masked indices that only covers answer tokens
masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
# Check that forward method was replaced
assert mock_model.forward != original_forward
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
labels=labels,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
class TestDiffusionPlugin:
"""Test the DiffusionPlugin."""
def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
"""Test loss computation when no tokens are masked."""
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 3
logits = torch.randn(1, seq_len, vocab_size)
# No tokens masked
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
def test_plugin_registers_loss_function(self):
"""Test that plugin registers diffusion loss function."""
with patch(
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
return_value=True,
) as mock_register:
plugin = DiffusionPlugin()
mock_register.assert_called_once()
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
def test_post_model_load_configuration(self):
"""Test that post_model_load configures model correctly."""
plugin = DiffusionPlugin()
# Loss should be zero when no tokens are masked
assert loss.item() == 0.0
assert loss.requires_grad
# Mock model and config
mock_model = Mock()
mock_model.config = Mock()
mock_cfg = Mock()
mock_cfg.eps = 1e-3
mock_cfg.importance_weighting = True
mock_cfg.mask_token_id = 32000
def test_cache_special_token_ids(self, diffusion_model_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_model_instance._special_token_ids == expected_tokens
with patch(
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
) as mock_patch:
result = plugin.post_model_load(mock_cfg, mock_model)
def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available."""
# Mock the parent model initialization to avoid loading pretrained weights
with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
model._cache_special_token_ids(None)
assert model._special_token_ids == set()
# Check model configuration
assert mock_model.config.loss_type == "ForDiffusionLM"
assert mock_model.config.diffusion_config is not None
assert mock_model.config.diffusion_config["eps"] == 1e-3
def test_forward_training_mode(self, diffusion_model_instance):
"""Test forward pass in training mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_output.logits = torch.randn(1, 5, 32000)
mock_forward.return_value = mock_output
# Set training mode
diffusion_model_instance.training = True
result = diffusion_model_instance.forward(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Should call parent forward and compute loss
assert mock_forward.called
assert hasattr(result, 'loss')
# Check model was patched
mock_patch.assert_called_once_with(mock_model)
def test_forward_inference_mode(self, diffusion_model_instance):
"""Test forward pass in inference mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_forward.return_value = mock_output
# Set inference mode
diffusion_model_instance.training = False
result = diffusion_model_instance.forward(
input_ids=input_ids,
return_dict=True
)
# Should just call parent forward without diffusion processing
assert mock_forward.called
assert result == mock_output
# Should return the model
assert result == mock_model
def test_post_trainer_create_stores_config(self, diffusion_config):
"""Test that post_trainer_create stores config on trainer."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Set config attributes
for attr, value in diffusion_config.model_dump().items():
setattr(mock_cfg, attr, value)
plugin.post_trainer_create(mock_cfg, mock_trainer)
# Check that diffusion config was stored on trainer
assert hasattr(mock_trainer, "diffusion_config")
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
def test_add_callbacks_post_trainer_with_generation_enabled(self):
"""Test callback addition when generation is enabled."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation enabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
with patch(
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
) as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return one callback
assert len(callbacks) == 1
mock_callback_class.assert_called_once_with(mock_trainer)
def test_add_callbacks_post_trainer_with_generation_disabled(self):
"""Test callback addition when generation is disabled."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation disabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return no callbacks
assert len(callbacks) == 0
class TestLossRegistration:
"""Test loss function registration."""
def test_register_diffusion_loss(self):
"""Test that loss function can be registered."""
with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
result = register_diffusion_loss()
assert result is True
assert "ForDiffusionLM" in mock_mapping
assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
def test_register_diffusion_loss_import_error(self):
"""Test fallback when LOSS_MAPPING import fails."""
# Patch the import to raise ImportError
with patch(
"builtins.__import__",
side_effect=ImportError("transformers.loss.loss_utils not found"),
):
result = register_diffusion_loss()
assert result is False