Compare commits

...

1 Commit

Author        SHA1        Message                           Date
Dan Saunders  1f75287a3a  diffusion custom models approach  2025-08-19 04:09:46 +00:00
8 changed files with 779 additions and 423 deletions

View File

@@ -64,11 +64,25 @@ learning_rate: 3e-4

 ## Supported Models

-Any models that support 4D attention masks should work out of the box. If not, please
-create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
+Currently supported base model types:
+
+- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
+- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
+
+The plugin automatically creates custom model classes that inherit from the base model
+while adding diffusion training capabilities. This provides full compatibility with
+HuggingFace's ecosystem for saving, loading, and inference.

 ## How It Works

+### Custom Model Architecture
+
+The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`)
+that inherit from standard HuggingFace models. During training, these models:
+
+1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
+2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
+3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
+
 ### Random Masking

 During training, tokens are randomly masked based on a sampled timestep:

 - Sample timestep `t` uniformly from [0, 1]
@@ -76,11 +90,10 @@ During training, tokens are randomly masked based on a sampled timestep:
 - Randomly mask tokens with probability `p`

 ### Bidirectional Attention

-The plugin uses native 4D attention masks to:
-- Enable bidirectional attention without patches
-- Allow all tokens to attend to all other tokens
-- Maintain proper padding masks
-- Work with modern `transformers` models out of the box
+The models override causal attention with bidirectional attention:
+- Creates 4D attention masks allowing all-to-all attention
+- Maintains proper padding and sample packing masks
+- Compatible with standard HuggingFace attention implementations

 ### Diffusion Loss
@@ -90,6 +103,22 @@ Loss is computed only on masked tokens with (optional) importance weighting:
 loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
 ```

+### Model Loading and Saving
+
+The custom models work seamlessly with HuggingFace's AutoModel system:
+
+```python
+from transformers import AutoModel, AutoConfig
+
+# Load a diffusion model
+model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
+
+# Save a diffusion model
+model.save_pretrained("path/to/save/diffusion/model")
+```
+
+During inference, the models behave like standard causal language models.
+
 ## Sample Generation

 When `generate_samples: true`, the plugin generates samples during training:
@@ -115,9 +144,19 @@ The plugin adds several metrics to track diffusion training:
 - `train/ce_loss`: Unweighted cross-entropy loss
 - `train/importance_weight_avg`: Average importance weight

+## Benefits of Custom Model Approach
+
+- **Type Safety**: Full IDE support and type checking
+- **HuggingFace Integration**: Works with AutoModel, Hub, pipelines
+- **Maintainability**: Clean architecture, no monkey patching
+- **Ecosystem Compatibility**: Standard save/load, PEFT support
+- **Testing**: Easier to test and debug
+
 ## Limitations

-- No flash attention support
+- **Model Support**: Currently limited to Llama and Mistral architectures
+- **Flash Attention**: Not yet optimized for flash attention
+- **Inference Speed**: Bidirectional attention is slower than causal for generation

 ## References
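As a standalone reference for the masking schedule described in the doc above, here is a minimal sketch of the linear schedule `p = (1 - eps) * t + eps` (the `eps` and `mask_token_id` defaults are illustrative, not fixed by the plugin):

```python
import torch

def forward_mask(input_ids: torch.Tensor, eps: float = 1e-3, mask_token_id: int = 32000):
    """Sample a timestep t per sequence, then mask tokens with probability p = (1 - eps) * t + eps."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)          # timestep in [0, 1)
    p_mask = ((1 - eps) * t + eps)[:, None].expand(-1, seq_len)  # [batch, seq]
    masked = torch.rand(batch_size, seq_len, device=input_ids.device) < p_mask
    noisy = torch.where(masked, mask_token_id, input_ids)        # replace masked tokens
    return noisy, masked, p_mask
```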

View File

@@ -1,6 +1,26 @@
"""Diffusion LM training plugin init.""" """Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin from .plugin import DiffusionPlugin
__all__ = ["DiffusionArgs", "DiffusionPlugin"] # Register custom configurations
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]
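With these registrations in place, the `Auto*` machinery can resolve the custom classes from a config alone; a small sketch (the tiny config values are illustrative):

```python
from transformers import AutoModel

# Importing the subpackage runs the registrations in its __init__
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig

config = LlamaForDiffusionConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = AutoModel.from_config(config)
print(type(model).__name__)  # LlamaForDiffusionLM
```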

View File

@@ -26,29 +26,31 @@ class DiffusionGenerationCallback(TrainerCallback):
         **kwargs,
     ):
         """Generate samples at specified intervals."""
+        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
         if (
             state.global_step > 0
-            and state.global_step % self.trainer.config.generation_interval == 0
+            and state.global_step % config.get('generation_interval', 100) == 0
         ):
             # Use eval dataloader if available, otherwise use train dataloader
             if (
                 hasattr(self.trainer, "eval_dataset")
                 and self.trainer.eval_dataset is not None
             ):
-                dataloader = self.trainer.callback_handler.eval_dataloader
+                dataloader = self.trainer.get_eval_dataloader()
             else:
-                dataloader = self.trainer.callback_handler.train_dataloader
+                dataloader = self.trainer.get_train_dataloader()

             # Generate samples
             samples = generate_samples(
                 model=self.trainer.model,
                 tokenizer=self.trainer.tokenizer,
                 dataloader=dataloader,
-                num_generation_samples=self.trainer.config.num_generation_samples,
-                max_length=self.trainer.config.generation_max_length,
-                num_diffusion_steps=self.trainer.config.generation_steps,
-                temperature=self.trainer.config.generation_temperature,
-                mask_token_id=self.trainer.config.mask_token_id,
+                num_generation_samples=config.get('num_generation_samples', 3),
+                max_length=config.get('generation_max_length', 256),
+                num_diffusion_steps=config.get('generation_steps', 10),
+                temperature=config.get('generation_temperature', 1.0),
+                mask_token_id=config.get('mask_token_id', 32000),
             )

             # Log samples
@@ -81,7 +83,8 @@ class DiffusionGenerationCallback(TrainerCallback):
             LOG.info("=" * 60)

-        if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero:
+        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
+        if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
             if wandb.run is not None:
                 wandb.log(
                     {
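For background, the gating on `state.global_step` above follows the standard `TrainerCallback` hook pattern; a minimal isolated sketch (the hook name and interval here are illustrative, not the plugin's exact wiring):

```python
from transformers import TrainerCallback

class EveryNStepsCallback(TrainerCallback):
    """Fire an action every `interval` optimizer steps, skipping step 0."""

    def __init__(self, interval: int = 100):
        self.interval = interval

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step > 0 and state.global_step % self.interval == 0:
            print(f"step {state.global_step}: interval reached")  # e.g. generate samples
```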

View File

@@ -0,0 +1,71 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
"""Configuration class for Llama models with diffusion training."""
model_type = "llama_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
class MistralForDiffusionConfig(MistralConfig):
"""Configuration class for Mistral models with diffusion training."""
model_type = "mistral_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
"""
Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
"""
model_type = "diffusion"
def __init__(self, **kwargs):
super().__init__(**kwargs)
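Because these configs subclass standard HuggingFace config types, they serialize like any other `PretrainedConfig`; a quick round-trip sketch (the path is illustrative):

```python
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig

config = LlamaForDiffusionConfig(mask_token_id=32000, noise_schedule="linear")
config.save_pretrained("/tmp/llama-diffusion")  # writes config.json with model_type "llama_diffusion"

restored = LlamaForDiffusionConfig.from_pretrained("/tmp/llama-diffusion")
assert restored.mask_token_id == 32000 and restored.noise_schedule == "linear"
```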

View File

@@ -0,0 +1,426 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
"""
Llama model for diffusion language modeling.
This model extends LlamaForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = LlamaForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
"""
Mistral model for diffusion language modeling.
This model extends MistralForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = MistralForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
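To see the sample-packing comparison in `_create_bidirectional_attention_mask` in isolation, the same element-wise logic can be run standalone (the packed mask values mirror the test fixtures further below):

```python
import torch

# Two samples packed into one sequence, tagged 1 and 2 in the attention mask
# (0 would mark padding), following the convention used above
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]])

mask_i = attention_mask.unsqueeze(2)  # [batch, seq, 1]
mask_j = attention_mask.unsqueeze(1)  # [batch, 1, seq]
bidirectional = ((mask_i == mask_j) & (mask_i > 0)).unsqueeze(1)

print(bidirectional[0, 0].int())
# Block-diagonal: tokens attend within their own sample, never across samples
```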

View File

@@ -1,13 +1,20 @@
"""Diffusion LM training plugin for Axolotl.""" """Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel from peft import PeftModel
from transformers import PreTrainedModel from transformers import AutoConfig, AutoModel, PreTrainedModel
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .trainer import DiffusionTrainer from .callbacks import DiffusionGenerationCallback
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -28,14 +35,64 @@ class DiffusionPlugin(BasePlugin):
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Perform actions after model is loaded.""" """Configure model after loading."""
self.cfg = cfg self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: # Set tokenizer on diffusion models for special token handling
"""Return custom trainer class for diffusion training.""" if hasattr(model, "set_tokenizer"):
return DiffusionTrainer # Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
"""Configure trainer after creation.""" """Add diffusion-specific callbacks after trainer creation."""
trainer.set_config(cfg) callbacks = []
# Store diffusion config on trainer for callbacks
trainer.diffusion_config = cfg
# Add generation callback if enabled
if cfg.get("generate_samples", False):
generation_callback = DiffusionGenerationCallback(trainer)
callbacks.append(generation_callback)
return callbacks

View File

@@ -1,279 +0,0 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
self._cache_special_token_ids()
if config.generate_samples:
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
@torch.compile
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
@torch.compile
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass
outputs = model(
input_ids=noisy_batch,
attention_mask=bidirectional_mask,
)
logits = outputs.logits
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
metrics = {
"loss": loss.item(),
"accuracy": accuracy.item(),
"mask_ratio": masked_indices.float().mean().item(),
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
"avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(),
}
if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs

View File

@@ -1,13 +1,14 @@
"""Tests for diffusion trainer integration.""" """Tests for diffusion model integration."""
# pylint: disable=redefined-outer-name,protected-access # pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
from axolotl.integrations.diffusion.trainer import DiffusionTrainer from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -24,37 +25,44 @@ def mock_tokenizer():
@pytest.fixture @pytest.fixture
def diffusion_config(): def diffusion_config():
"""Create a diffusion config.""" """Create a diffusion config."""
return DictDefault( return LlamaForDiffusionConfig(
{ mask_token_id=32000,
"mask_token_id": 32000, eps=1e-3,
"eps": 1e-3, importance_weighting=False,
"importance_weighting": False, sample_packing=False,
"sample_packing": False, # Basic llama config fields - smaller for testing
} vocab_size=1000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
) )
@pytest.fixture @pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config): def diffusion_model_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly.""" """Create a diffusion model instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods # Create a minimal model instance for testing
trainer = object.__new__(DiffusionTrainer) # Bypass __init__ model = object.__new__(LlamaForDiffusionLM)
trainer.config = diffusion_config model.config = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos model._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer model.training = True
trainer.store_metrics = Mock() # Mock metrics storage
return trainer # Set tokenizer
model.set_tokenizer(mock_tokenizer)
return model
class TestDiffusionTrainer: class TestDiffusionModel:
"""Test the DiffusionTrainer class.""" """Test the DiffusionModel class."""
def test_forward_process_basic(self, diffusion_trainer_instance): def test_forward_process_basic(self, diffusion_model_instance):
"""Test basic forward process without labels.""" """Test basic forward process without labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = ( noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(input_ids, eps=0.1) diffusion_model_instance._forward_process(input_ids, eps=0.1)
) )
# Check shapes # Check shapes
@@ -67,18 +75,18 @@ class TestDiffusionTrainer:
assert not masked_indices[special_token_positions].any() assert not masked_indices[special_token_positions].any()
# Check that mask token is applied # Check that mask token is applied
mask_token_id = diffusion_trainer_instance._config.mask_token_id mask_token_id = diffusion_model_instance.config.mask_token_id
masked_positions = masked_indices masked_positions = masked_indices
if masked_positions.any(): if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all() assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer_instance): def test_forward_process_with_labels(self, diffusion_model_instance):
"""Test forward process with SFT labels.""" """Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = ( noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process( diffusion_model_instance._forward_process(
input_ids, labels=labels, eps=0.1 input_ids, labels=labels, eps=0.1
) )
) )
@@ -100,12 +108,12 @@ class TestDiffusionTrainer:
# Verify that masked_indices respects the answer mask # Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any() assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): def test_forward_process_with_attention_mask(self, diffusion_model_instance):
"""Test forward process with attention mask.""" """Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process( _, masked_indices, p_mask = diffusion_model_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1 input_ids, attention_mask=attention_mask, eps=0.1
) )
@@ -114,11 +122,11 @@ class TestDiffusionTrainer:
assert not masked_indices[padding_positions].any() assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all() assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
"""Test bidirectional attention mask without sample packing.""" """Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask( mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids input_ids
) )
@@ -128,15 +136,15 @@ class TestDiffusionTrainer:
assert mask.all() assert mask.all()
def test_bidirectional_attention_mask_with_packing( def test_bidirectional_attention_mask_with_packing(
self, diffusion_trainer_instance self, diffusion_model_instance
): ):
"""Test bidirectional attention mask with sample packing.""" """Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance._config.sample_packing = True diffusion_model_instance.config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2) # Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask( mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids, attention_mask input_ids, attention_mask
) )
@@ -148,124 +156,135 @@ class TestDiffusionTrainer:
         assert not mask[0, 0, 2, 4].item()
         assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other

-    def test_compute_loss_basic(self, diffusion_trainer_instance):
+    def test_compute_loss_basic(self, diffusion_model_instance):
         """Test basic loss computation."""
-        # Mock model that returns logits
-        mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 5
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)

-        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
+        # Create mock data for loss computation
+        vocab_size = 1000
+        seq_len = 5
+        logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+
+        # Create a simple masked indices tensor (mask middle tokens)
+        masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
+        p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
+
+        loss = diffusion_model_instance._compute_diffusion_loss(
+            input_ids=input_ids,
+            logits=logits,
+            masked_indices=masked_indices,
+            p_mask=p_mask,
         )

         # Check that loss is computed
         assert isinstance(loss, torch.Tensor)
         assert loss.requires_grad
-        assert outputs == mock_outputs
-
-        # Check that metrics were stored
-        diffusion_trainer_instance.store_metrics.assert_called_once()

-    def test_compute_loss_with_labels(self, diffusion_trainer_instance):
+    def test_compute_loss_with_labels(self, diffusion_model_instance):
         """Test loss computation with SFT labels."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 5
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
         labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids, labels=labels
+        # Create mock data for loss computation
+        vocab_size = 1000
+        seq_len = 5
+        logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+
+        # Create masked indices that only covers answer tokens
+        masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
+        p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
+
+        loss = diffusion_model_instance._compute_diffusion_loss(
+            input_ids=input_ids,
+            labels=labels,
+            logits=logits,
+            masked_indices=masked_indices,
+            p_mask=p_mask,
         )

         # Check that loss is computed
         assert isinstance(loss, torch.Tensor)
         assert loss.requires_grad

-        # Check that SFT metrics were added
-        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
-        assert "answer_ratio" in call_args
-        assert "avg_answer_length" in call_args
-
-    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
+    def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
         """Test loss computation when no tokens are masked."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 3
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
-        # Only special tokens (which won't be masked)
         input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)

-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
+        # Create mock data for loss computation
+        vocab_size = 1000
+        seq_len = 3
+        logits = torch.randn(1, seq_len, vocab_size)
+
+        # No tokens masked
+        masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
+        p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
+
+        loss = diffusion_model_instance._compute_diffusion_loss(
+            input_ids=input_ids,
+            logits=logits,
+            masked_indices=masked_indices,
+            p_mask=p_mask,
        )

         # Loss should be zero when no tokens are masked
         assert loss.item() == 0.0
         assert loss.requires_grad

-    def test_cache_special_token_ids(self, diffusion_trainer_instance):
+    def test_cache_special_token_ids(self, diffusion_model_instance):
         """Test caching of special token IDs."""
         # Should cache BOS, EOS, PAD tokens
         expected_tokens = {0, 1, 2}  # pad, bos, eos
-        assert diffusion_trainer_instance._special_token_ids == expected_tokens
+        assert diffusion_model_instance._special_token_ids == expected_tokens

     def test_cache_special_token_ids_no_tokenizer(self):
         """Test caching when no tokenizer is available."""
-        trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
-        trainer.processing_class = None
-        trainer._cache_special_token_ids()
-
-        assert trainer._special_token_ids == set()
-
-    def test_main_compute_loss_interface(self, diffusion_trainer_instance):
-        """Test the main compute_loss interface."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        mock_outputs.logits = torch.randn(1, 5, 1000)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
-        inputs = {
-            "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
-            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
-            "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
-        }
-
-        # Test without return_outputs
-        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
-        assert isinstance(loss, torch.Tensor)
-
-        # Test with return_outputs
-        loss, outputs = diffusion_trainer_instance.compute_loss(
-            mock_model, inputs, return_outputs=True
-        )
-        assert isinstance(loss, torch.Tensor)
-        assert outputs == mock_outputs
-
-    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
-        """Test that missing input_ids raises ValueError."""
-        mock_model = Mock()
-        inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
-
-        with pytest.raises(ValueError, match="input_ids is required"):
-            diffusion_trainer_instance.compute_loss(mock_model, inputs)
+        # Mock the parent model initialization to avoid loading pretrained weights
+        with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
+            model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
+            model._cache_special_token_ids(None)
+            assert model._special_token_ids == set()
+
+    def test_forward_training_mode(self, diffusion_model_instance):
+        """Test forward pass in training mode."""
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
+
+        # Mock the parent forward method
+        with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
+            mock_output = Mock()
+            mock_output.logits = torch.randn(1, 5, 32000)
+            mock_forward.return_value = mock_output
+
+            # Set training mode
+            diffusion_model_instance.training = True
+
+            result = diffusion_model_instance.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True
+            )
+
+            # Should call parent forward and compute loss
+            assert mock_forward.called
+            assert hasattr(result, 'loss')
+
+    def test_forward_inference_mode(self, diffusion_model_instance):
+        """Test forward pass in inference mode."""
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        # Mock the parent forward method
+        with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
+            mock_output = Mock()
+            mock_forward.return_value = mock_output
+
+            # Set inference mode
+            diffusion_model_instance.training = False
+
+            result = diffusion_model_instance.forward(
+                input_ids=input_ids,
+                return_dict=True
+            )
+
+            # Should just call parent forward without diffusion processing
+            assert mock_forward.called
+            assert result == mock_output