diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md
index f79d5a46b..32cf21218 100644
--- a/src/axolotl/integrations/diffusion/README.md
+++ b/src/axolotl/integrations/diffusion/README.md
@@ -64,11 +64,25 @@ learning_rate: 3e-4
 ## Supported Models
 
-Any models that support 4D attention masks should work out of the box. If not, please
-create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
+Currently supported base model types:
+- **Llama** (`meta-llama/Llama-*`, etc.) - uses `LlamaForDiffusionLM`
+- **Mistral** (`mistralai/Mistral-*`, etc.) - uses `MistralForDiffusionLM`
+
+The plugin automatically creates custom model classes that inherit from the base model
+class and add diffusion training capabilities, so saving, loading, and inference remain
+fully compatible with the HuggingFace ecosystem.
 
 ## How It Works
 
+### Custom Model Architecture
+
+The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
+standard HuggingFace models. During training, these models:
+
+1. **Apply the forward diffusion process**: randomly mask tokens based on sampled timesteps
+2. **Use bidirectional attention**: override causal attention with full bidirectional attention
+3. **Compute the diffusion loss**: compute loss only on masked tokens, with optional importance weighting
+
 ### Random Masking
 During training, tokens are randomly masked based on a sampled timestep:
 - Sample timestep `t` uniformly from [0, 1]
@@ -76,11 +90,10 @@ During training, tokens are randomly masked based on a sampled timestep:
 - Randomly mask tokens with probability `p`
 
 ### Bidirectional Attention
-The plugin uses native 4D attention masks to:
-- Enable bidirectional attention without patches
-- Allow all tokens to attend to all other tokens
-- Maintain proper padding masks
-- Work with modern `transformers` models out of the box
+The models replace causal attention with bidirectional attention:
+- A 4D attention mask lets every token attend to every other token
+- Padding and sample-packing masks are preserved
+- The mask works with standard HuggingFace attention implementations
 
 ### Diffusion Loss
 
@@ -90,6 +103,22 @@ Loss is computed only on masked tokens with (optional) importance weighting:
 ```
 loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
 ```
 
+### Model Loading and Saving
+
+The custom models plug into HuggingFace's `AutoModel` API. Importing the plugin registers
+the custom configuration and model classes:
+
+```python
+from transformers import AutoModel
+
+import axolotl.integrations.diffusion  # registers the diffusion configs and models
+
+# Load a diffusion model
+model = AutoModel.from_pretrained("path/to/diffusion/model")
+
+# Save a diffusion model
+model.save_pretrained("path/to/save/diffusion/model")
+```
+
+During inference, the models behave like standard causal language models.
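+
+### Putting It Together
+
+For concreteness, here is a minimal, self-contained sketch of the masking and loss described
+above. It illustrates the idea rather than the plugin's actual implementation; the model is
+assumed to accept a 4D boolean attention mask, and `mask_token_id` / `eps` mirror the config
+defaults.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def diffusion_step(model, input_ids, mask_token_id=32000, eps=1e-3):
+    """One illustrative training step: mask, predict, then score masked tokens."""
+    batch_size, seq_len = input_ids.shape
+    # Sample a timestep per sequence and derive the per-token masking probability.
+    t = torch.rand(batch_size, device=input_ids.device)
+    p_mask = ((1 - eps) * t + eps)[:, None].expand(-1, seq_len)
+    # Mask each token independently with probability p_mask.
+    masked = torch.rand_like(p_mask) < p_mask
+    noisy_ids = torch.where(masked, mask_token_id, input_ids)
+    # Bidirectional attention: every token may attend to every other token.
+    attn = torch.ones(
+        batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=input_ids.device
+    )
+    logits = model(input_ids=noisy_ids, attention_mask=attn).logits
+    # Importance-weighted cross-entropy on the masked positions only.
+    ce = F.cross_entropy(logits[masked], input_ids[masked], reduction="none")
+    return (ce / p_mask[masked]).sum() / input_ids.numel()
+```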
+
 ## Sample Generation
 
 When `generate_samples: true`, the plugin generates samples during training:
@@ -115,9 +144,19 @@ The plugin adds several metrics to track diffusion training:
 - `train/ce_loss`: Unweighted cross-entropy loss
 - `train/importance_weight_avg`: Average importance weight
 
+## Benefits of Custom Model Approach
+
+- ✅ **Type Safety**: Full IDE support and type checking
+- ✅ **HuggingFace Integration**: Works with AutoModel, Hub, pipelines
+- ✅ **Maintainability**: Clean architecture, no monkey patching
+- ✅ **Ecosystem Compatibility**: Standard save/load, PEFT support
+- ✅ **Testing**: Easier to test and debug
+
 ## Limitations
 
-- No flash attention support
+- **Model Support**: Currently limited to Llama and Mistral architectures
+- **Flash Attention**: Not yet optimized for flash attention
+- **Inference Speed**: Bidirectional attention is slower than causal for generation
 
 ## References
diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py
index 942f8051f..bbc578057 100644
--- a/src/axolotl/integrations/diffusion/__init__.py
+++ b/src/axolotl/integrations/diffusion/__init__.py
@@ -1,6 +1,26 @@
 """Diffusion LM training plugin init."""
 
+from transformers import AutoConfig, AutoModel
+
 from .args import DiffusionArgs
+from .configuration import (
+    DiffusionConfig,
+    LlamaForDiffusionConfig,
+    MistralForDiffusionConfig,
+)
+from .models import LlamaForDiffusionLM, MistralForDiffusionLM
 from .plugin import DiffusionPlugin
 
-__all__ = ["DiffusionArgs", "DiffusionPlugin"]
+# Register custom configurations
+AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
+AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
+
+# Register custom models
+AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
+AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
+
+__all__ = [
+    "DiffusionArgs",
+    "DiffusionPlugin",
+    "DiffusionConfig",
+    "LlamaForDiffusionConfig",
+    "MistralForDiffusionConfig",
+    "LlamaForDiffusionLM",
+    "MistralForDiffusionLM",
+]
diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py
index 595bf49a1..8994db3a0 100644
--- a/src/axolotl/integrations/diffusion/callbacks.py
+++ b/src/axolotl/integrations/diffusion/callbacks.py
@@ -26,29 +26,31 @@ class DiffusionGenerationCallback(TrainerCallback):
         **kwargs,
     ):
         """Generate samples at specified intervals."""
+        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
+
         if (
             state.global_step > 0
-            and state.global_step % self.trainer.config.generation_interval == 0
+            and state.global_step % config.get('generation_interval', 100) == 0
         ):
             # Use eval dataloader if available, otherwise use train dataloader
             if (
                 hasattr(self.trainer, "eval_dataset")
                 and self.trainer.eval_dataset is not None
             ):
-                dataloader = self.trainer.callback_handler.eval_dataloader
+                dataloader = self.trainer.get_eval_dataloader()
             else:
-                dataloader = self.trainer.callback_handler.train_dataloader
+                dataloader = self.trainer.get_train_dataloader()
 
             # Generate samples
             samples = generate_samples(
                 model=self.trainer.model,
                 tokenizer=self.trainer.tokenizer,
                 dataloader=dataloader,
-                num_generation_samples=self.trainer.config.num_generation_samples,
-                max_length=self.trainer.config.generation_max_length,
-                num_diffusion_steps=self.trainer.config.generation_steps,
-                temperature=self.trainer.config.generation_temperature,
-                mask_token_id=self.trainer.config.mask_token_id,
+
num_generation_samples=config.get('num_generation_samples', 3), + max_length=config.get('generation_max_length', 256), + num_diffusion_steps=config.get('generation_steps', 10), + temperature=config.get('generation_temperature', 1.0), + mask_token_id=config.get('mask_token_id', 32000), ) # Log samples @@ -81,7 +83,8 @@ class DiffusionGenerationCallback(TrainerCallback): LOG.info("=" * 60) - if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero: + config = getattr(self.trainer, 'diffusion_config', self.trainer.args) + if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero: if wandb.run is not None: wandb.log( { diff --git a/src/axolotl/integrations/diffusion/configuration.py b/src/axolotl/integrations/diffusion/configuration.py new file mode 100644 index 000000000..8b714bc78 --- /dev/null +++ b/src/axolotl/integrations/diffusion/configuration.py @@ -0,0 +1,71 @@ +"""Configuration classes for diffusion language models.""" + +from transformers import LlamaConfig, MistralConfig + + +class LlamaForDiffusionConfig(LlamaConfig): + """Configuration class for Llama models with diffusion training.""" + + model_type = "llama_diffusion" + + def __init__( + self, + mask_token_id: int = 32000, + eps: float = 1e-3, + importance_weighting: bool = False, + sample_packing: bool = False, + min_mask_ratio: float = 0.0, + max_mask_ratio: float = 1.0, + noise_schedule: str = "linear", + **kwargs, + ): + super().__init__(**kwargs) + + # Diffusion-specific parameters + self.mask_token_id = mask_token_id + self.eps = eps + self.importance_weighting = importance_weighting + self.sample_packing = sample_packing + self.min_mask_ratio = min_mask_ratio + self.max_mask_ratio = max_mask_ratio + self.noise_schedule = noise_schedule + + +class MistralForDiffusionConfig(MistralConfig): + """Configuration class for Mistral models with diffusion training.""" + + model_type = "mistral_diffusion" + + def __init__( + self, + mask_token_id: int = 32000, + eps: float = 1e-3, + importance_weighting: bool = False, + sample_packing: bool = False, + min_mask_ratio: float = 0.0, + max_mask_ratio: float = 1.0, + noise_schedule: str = "linear", + **kwargs, + ): + super().__init__(**kwargs) + + # Diffusion-specific parameters + self.mask_token_id = mask_token_id + self.eps = eps + self.importance_weighting = importance_weighting + self.sample_packing = sample_packing + self.min_mask_ratio = min_mask_ratio + self.max_mask_ratio = max_mask_ratio + self.noise_schedule = noise_schedule + + +# Keep the base class for backward compatibility but mark as deprecated +class DiffusionConfig(LlamaForDiffusionConfig): + """ + Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead. 
+ """ + + model_type = "diffusion" + + def __init__(self, **kwargs): + super().__init__(**kwargs) \ No newline at end of file diff --git a/src/axolotl/integrations/diffusion/models.py b/src/axolotl/integrations/diffusion/models.py new file mode 100644 index 000000000..8679e5ae3 --- /dev/null +++ b/src/axolotl/integrations/diffusion/models.py @@ -0,0 +1,426 @@ +"""Custom model classes for diffusion language models.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import LlamaForCausalLM, MistralForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig + + +class DiffusionModelMixin: + """Mixin class providing diffusion functionality to language models.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._special_token_ids = None + + def _cache_special_token_ids(self, tokenizer=None): + """Cache special token IDs to avoid repeated tokenizer access.""" + if tokenizer is None: + self._special_token_ids = set() + return + + special_tokens = set() + + if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None: + special_tokens.add(tokenizer.bos_token_id) + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + special_tokens.add(tokenizer.eos_token_id) + if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + special_tokens.add(tokenizer.pad_token_id) + + self._special_token_ids = special_tokens + + def _forward_process( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + eps: float = 1e-3, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward noising process. A timestep is sampled along the process, and tokens are + masked with probability determined by the configured noise schedule. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + eps: Small epsilon value for minimum masking probability. + + Returns: + noisy_batch: Input with some tokens masked. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. 
+ """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Sample random timesteps for each sample in batch + t = torch.rand(batch_size, device=device) + + # Calculate masking probability with epsilon + p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] + + # Don't mask padding tokens if attention_mask is provided + if attention_mask is not None: + valid_mask = attention_mask.bool() + p_mask = p_mask * valid_mask.float() + + # Create mask to exclude special tokens + special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) + if self._special_token_ids: + for token_id in self._special_token_ids: + special_token_mask |= input_ids == token_id + + # Create random mask based on p_mask + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + masked_indices = masked_indices & ~special_token_mask + if attention_mask is not None: + masked_indices = masked_indices & attention_mask.bool() + + # For SFT data, only mask answer tokens + if labels is not None: + answer_mask = labels != -100 + masked_indices = masked_indices & answer_mask + + # Create masked input + mask_token_id = self.config.mask_token_id + noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) + + return noisy_batch, masked_indices, p_mask + + def _create_bidirectional_attention_mask( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Create bidirectional attention mask to override default causal masking. Handles + sample-packed sequences where different samples are identified by different + attention mask values. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len] + + Returns: + bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + if attention_mask is None or not self.config.sample_packing: + return torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device + ) + + # Create attention mask by comparing sample IDs element-wise + mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] + mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] + + # Tokens can attend to each other if they have the same non-zero sample ID + bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) + + # Add head dimension: [batch_size, 1, seq_len, seq_len] + bidirectional_mask = bidirectional_mask.unsqueeze(1) + + return bidirectional_mask + + def _compute_diffusion_loss( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + logits: torch.Tensor | None = None, + masked_indices: torch.Tensor | None = None, + p_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Compute diffusion loss given logits and masking information. + + Args: + input_ids: Ground truth token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + logits: Model logits [batch_size, seq_len, vocab_size]. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. + + Returns: + loss: Cross-entropy loss. 
+ """ + if masked_indices.sum() > 0: + valid_indices = torch.where(masked_indices) + batch_indices, seq_indices = valid_indices + + masked_logits = logits[batch_indices, seq_indices] + masked_targets = input_ids[batch_indices, seq_indices] + masked_p_mask = p_mask[batch_indices, seq_indices] + + # Compute cross-entropy loss without reduction + token_loss = F.cross_entropy( + masked_logits.float(), masked_targets, reduction="none" + ) + + if self.config.importance_weighting: + masked_p_mask = masked_p_mask.float() + weighted_loss = token_loss / masked_p_mask + else: + weighted_loss = token_loss + + # Final loss: sum weighted losses, normalize + if labels is not None: + # For SFT data: normalize by answer length per sample + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] + + # Get batch indices for masked tokens + masked_batch_indices = batch_indices + + # Sum losses per sample and divide by answer length + loss_per_sample = torch.zeros( + input_ids.shape[0], device=input_ids.device + ) + for i in range(input_ids.shape[0]): + sample_mask = masked_batch_indices == i + if sample_mask.sum() > 0: + sample_loss = weighted_loss[sample_mask].sum() + loss_per_sample[i] = sample_loss / answer_lengths[i] + + loss = loss_per_sample.mean() + else: + # Original normalization for non-SFT data + loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) + else: + loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + + return loss + + +class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM): + """ + Llama model for diffusion language modeling. + + This model extends LlamaForCausalLM with diffusion training capabilities, + including bidirectional attention and forward diffusion process. + """ + + config_class = LlamaForDiffusionConfig + + def __init__(self, config): + super().__init__(config) + + # Initialize diffusion-specific attributes + self._special_token_ids = None + + # Initialize weights and apply final processing + self.post_init() + + def set_tokenizer(self, tokenizer): + """Set tokenizer for special token handling.""" + self._cache_special_token_ids(tokenizer) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Forward pass with diffusion training logic. + + During training, applies forward diffusion process and bidirectional attention. + During inference, behaves like standard causal language model. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and input_ids is not None: + # Apply diffusion process during training + original_input_ids = input_ids.clone() + + # Apply forward process to get noisy input + noisy_input_ids, masked_indices, p_mask = self._forward_process( + input_ids, attention_mask, labels, self.config.eps + ) + + # Create bidirectional attention mask + bidirectional_attention_mask = self._create_bidirectional_attention_mask( + input_ids, attention_mask + ) + + # Forward pass with noisy input and bidirectional attention + outputs = super().forward( + input_ids=noisy_input_ids, + attention_mask=bidirectional_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=None, # Don't use standard loss computation + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Compute diffusion loss + loss = self._compute_diffusion_loss( + original_input_ids, + attention_mask, + labels, + outputs.logits, + masked_indices, + p_mask, + ) + + if return_dict: + outputs.loss = loss + return outputs + else: + return (loss,) + outputs[1:] + else: + # Standard forward pass for inference + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + +class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM): + """ + Mistral model for diffusion language modeling. + + This model extends MistralForCausalLM with diffusion training capabilities, + including bidirectional attention and forward diffusion process. + """ + + config_class = MistralForDiffusionConfig + + def __init__(self, config): + super().__init__(config) + + # Initialize diffusion-specific attributes + self._special_token_ids = None + + # Initialize weights and apply final processing + self.post_init() + + def set_tokenizer(self, tokenizer): + """Set tokenizer for special token handling.""" + self._cache_special_token_ids(tokenizer) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Forward pass with diffusion training logic. + + During training, applies forward diffusion process and bidirectional attention. + During inference, behaves like standard causal language model. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and input_ids is not None: + # Apply diffusion process during training + original_input_ids = input_ids.clone() + + # Apply forward process to get noisy input + noisy_input_ids, masked_indices, p_mask = self._forward_process( + input_ids, attention_mask, labels, self.config.eps + ) + + # Create bidirectional attention mask + bidirectional_attention_mask = self._create_bidirectional_attention_mask( + input_ids, attention_mask + ) + + # Forward pass with noisy input and bidirectional attention + outputs = super().forward( + input_ids=noisy_input_ids, + attention_mask=bidirectional_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=None, # Don't use standard loss computation + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Compute diffusion loss + loss = self._compute_diffusion_loss( + original_input_ids, + attention_mask, + labels, + outputs.logits, + masked_indices, + p_mask, + ) + + if return_dict: + outputs.loss = loss + return outputs + else: + return (loss,) + outputs[1:] + else: + # Standard forward pass for inference + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) \ No newline at end of file diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py index c31f48b03..277acbcce 100644 --- a/src/axolotl/integrations/diffusion/plugin.py +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -1,13 +1,20 @@ """Diffusion LM training plugin for Axolotl.""" +from typing import TYPE_CHECKING + from peft import PeftModel -from transformers import PreTrainedModel +from transformers import AutoConfig, AutoModel, PreTrainedModel from axolotl.integrations.base import BasePlugin from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -from .trainer import DiffusionTrainer +from .callbacks import DiffusionGenerationCallback +from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig +from .models import LlamaForDiffusionLM, MistralForDiffusionLM + +if TYPE_CHECKING: + from transformers import Trainer LOG = get_logger(__name__) @@ -28,14 +35,64 @@ class DiffusionPlugin(BasePlugin): """Returns the pydantic model for LLaDA plugin arguments.""" return "axolotl.integrations.diffusion.DiffusionArgs" + def pre_model_load(self, cfg: DictDefault): + """Configure model loading to use diffusion model classes.""" + # Map base model types to diffusion equivalents + base_model_type = cfg.get("model_type") + + if base_model_type == "llama": + # Create diffusion config from base config + diffusion_config = LlamaForDiffusionConfig( + mask_token_id=getattr(cfg, "mask_token_id", 32000), + eps=getattr(cfg, "eps", 1e-3), + importance_weighting=getattr(cfg, 
"importance_weighting", False), + sample_packing=getattr(cfg, "sample_packing", False), + min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0), + max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0), + noise_schedule=getattr(cfg, "noise_schedule", "linear"), + ) + + # Override model type for loading + cfg.model_type = "llama_diffusion" + + elif base_model_type == "mistral": + # Create diffusion config from base config + diffusion_config = MistralForDiffusionConfig( + mask_token_id=getattr(cfg, "mask_token_id", 32000), + eps=getattr(cfg, "eps", 1e-3), + importance_weighting=getattr(cfg, "importance_weighting", False), + sample_packing=getattr(cfg, "sample_packing", False), + min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0), + max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0), + noise_schedule=getattr(cfg, "noise_schedule", "linear"), + ) + + # Override model type for loading + cfg.model_type = "mistral_diffusion" + else: + LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}") + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): - """Perform actions after model is loaded.""" + """Configure model after loading.""" self.cfg = cfg + + # Set tokenizer on diffusion models for special token handling + if hasattr(model, "set_tokenizer"): + # Get tokenizer from cfg if available + tokenizer = getattr(cfg, "tokenizer", None) + if tokenizer is not None: + model.set_tokenizer(tokenizer) - def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: - """Return custom trainer class for diffusion training.""" - return DiffusionTrainer - - def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): - """Configure trainer after creation.""" - trainer.set_config(cfg) + def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"): + """Add diffusion-specific callbacks after trainer creation.""" + callbacks = [] + + # Store diffusion config on trainer for callbacks + trainer.diffusion_config = cfg + + # Add generation callback if enabled + if cfg.get("generate_samples", False): + generation_callback = DiffusionGenerationCallback(trainer) + callbacks.append(generation_callback) + + return callbacks diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py deleted file mode 100644 index dc62035d5..000000000 --- a/src/axolotl/integrations/diffusion/trainer.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Custom trainer for diffusion LM training.""" - -from typing import Any, Literal - -import torch -import torch.nn.functional as F -from torch import nn - -from axolotl.core.trainers.base import AxolotlTrainer -from axolotl.utils.dict import DictDefault -from axolotl.utils.logging import get_logger - -from .callbacks import DiffusionGenerationCallback - -LOG = get_logger(__name__) - - -class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors - """Custom trainer for diffusion LM training that overrides loss computation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.config = None - self._special_token_ids = None - - def set_config(self, config: DictDefault): - """Set config for diffusion training.""" - self.config = config - self._cache_special_token_ids() - - if config.generate_samples: - generation_callback = DiffusionGenerationCallback(self) - self.add_callback(generation_callback) - - def compute_loss( - self, - model: nn.Module, - inputs: dict[str, torch.Tensor], - return_outputs: bool = False, - 
num_items_in_batch: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: - """Override compute_loss to use diffusion loss.""" - input_ids = inputs.get("input_ids") - attention_mask = inputs.get("attention_mask") - labels = inputs.get("labels") - - if input_ids is None: - raise ValueError("input_ids is required for diffusion training") - - loss, outputs = self._compute_diffusion_loss( - model, input_ids, attention_mask, labels - ) - - if return_outputs: - return loss, outputs - return loss - - def _cache_special_token_ids(self): - """Cache special token IDs to avoid repeated tokenizer access.""" - if self.processing_class is None: - self._special_token_ids = set() - return - - tokenizer = self.processing_class - special_tokens = set() - - if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None: - special_tokens.add(tokenizer.bos_token_id) - if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: - special_tokens.add(tokenizer.eos_token_id) - if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: - special_tokens.add(tokenizer.pad_token_id) - - self._special_token_ids = special_tokens - - @torch.compile - def _forward_process( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor | None = None, - labels: torch.Tensor | None = None, - eps: float = 1e-3, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Forward noising process. A timestep is sampled along the process, and tokens are - masked with probability determined by the configured noise schedule. - - Args: - input_ids: Input token ids [batch_size, seq_len]. - attention_mask: Attention mask [batch_size, seq_len]. - labels: Labels for SFT training [batch_size, seq_len]. - eps: Small epsilon value for minimum masking probability. - - Returns: - noisy_batch: Input with some tokens masked. - masked_indices: Boolean mask indicating which tokens were masked. - p_mask: Masking probabilities for each token [batch_size, seq_len]. 
- """ - batch_size, seq_len = input_ids.shape - device = input_ids.device - - # Sample random timesteps for each sample in batch - t = torch.rand(batch_size, device=device) - - # Calculate masking probability with epsilon - p_mask = (1 - eps) * t + eps # [batch_size] - p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] - - # Don't mask padding tokens if attention_mask is provided - if attention_mask is not None: - valid_mask = attention_mask.bool() - p_mask = p_mask * valid_mask.float() - - # Create mask to exclude special tokens - special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) - if self._special_token_ids: - for token_id in self._special_token_ids: - special_token_mask |= input_ids == token_id - - # Create random mask based on p_mask - masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask - masked_indices = masked_indices & ~special_token_mask - if attention_mask is not None: - masked_indices = masked_indices & attention_mask.bool() - - # For SFT data, only mask answer tokens - if labels is not None: - answer_mask = labels != -100 - masked_indices = masked_indices & answer_mask - - # Create masked input - mask_token_id = self.config.mask_token_id - noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) - - return noisy_batch, masked_indices, p_mask - - @torch.compile - def _create_bidirectional_attention_mask( - self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Create bidirectional attention mask to override default causal masking. Handles - sample-packed sequences where different samples are identified by different - attention mask values. - - Args: - input_ids: Input token ids [batch_size, seq_len]. - attention_mask: Attention mask [batch_size, seq_len] - - Returns: - bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]. - """ - batch_size, seq_len = input_ids.shape - device = input_ids.device - - if attention_mask is None or not self.config.sample_packing: - return torch.ones( - batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device - ) - - # Create attention mask by comparing sample IDs element-wise - mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] - mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] - - # Tokens can attend to each other if they have the same non-zero sample ID - bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) - - # Add head dimension: [batch_size, 1, seq_len, seq_len] - bidirectional_mask = bidirectional_mask.unsqueeze(1) - - return bidirectional_mask - - def _compute_diffusion_loss( - self, - model: nn.Module, - input_ids: torch.Tensor, - attention_mask: torch.Tensor | None = None, - labels: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | Any]: - """ - Compute diffusion loss. - - Args: - model: The model to compute loss for. - input_ids: Ground truth token ids [batch_size, seq_len]. - attention_mask: Attention mask [batch_size, seq_len]. - labels: Labels for SFT training [batch_size, seq_len]. - - Returns: - loss: Cross-entropy loss. - metrics: Dictionary of metrics. 
- """ - # Apply forward process - noisy_batch, masked_indices, p_mask = self._forward_process( - input_ids, attention_mask, labels, self.config.eps - ) - - # Create bidirectional attention mask - bidirectional_mask = self._create_bidirectional_attention_mask( - input_ids, attention_mask - ) - - # Forward pass - outputs = model( - input_ids=noisy_batch, - attention_mask=bidirectional_mask, - ) - logits = outputs.logits - - if masked_indices.sum() > 0: - valid_indices = torch.where(masked_indices) - batch_indices, seq_indices = valid_indices - - masked_logits = logits[batch_indices, seq_indices] - masked_targets = input_ids[batch_indices, seq_indices] - masked_p_mask = p_mask[batch_indices, seq_indices] - - # Compute cross-entropy loss without reduction - token_loss = F.cross_entropy( - masked_logits.float(), masked_targets, reduction="none" - ) - - if self.config.importance_weighting: - masked_p_mask = masked_p_mask.float() - weighted_loss = token_loss / masked_p_mask - else: - weighted_loss = token_loss - - # Final loss: sum weighted losses, normalize - if labels is not None: - # For SFT data: normalize by answer length per sample - answer_mask = labels != -100 - answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] - - # Get batch indices for masked tokens - masked_batch_indices = batch_indices - - # Sum losses per sample and divide by answer length - loss_per_sample = torch.zeros( - input_ids.shape[0], device=input_ids.device - ) - for i in range(input_ids.shape[0]): - sample_mask = masked_batch_indices == i - if sample_mask.sum() > 0: - sample_loss = weighted_loss[sample_mask].sum() - loss_per_sample[i] = sample_loss / answer_lengths[i] - - loss = loss_per_sample.mean() - else: - # Original normalization for non-SFT data - loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) - - ce_loss = token_loss.mean() - - # Compute accuracy on masked tokens - with torch.no_grad(): - pred_tokens = masked_logits.argmax(dim=-1) - accuracy = (pred_tokens == masked_targets).float().mean() - else: - loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) - accuracy = torch.tensor(0.0, device=input_ids.device) - ce_loss = torch.tensor(0.0, device=input_ids.device) - masked_p_mask = torch.tensor(1.0, device=input_ids.device) - - metrics = { - "loss": loss.item(), - "accuracy": accuracy.item(), - "mask_ratio": masked_indices.float().mean().item(), - "num_masked_tokens": (masked_indices.sum().item(), "sum"), - "avg_p_mask": p_mask[masked_indices].mean().item(), - "ce_loss": ce_loss.item(), - } - if self.config.importance_weighting: - metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() - - train_eval: Literal["train", "eval"] = "train" if model.training else "eval" - self.store_metrics(metrics, train_eval=train_eval) - - return loss, outputs diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py index 583597238..afe6df946 100644 --- a/tests/integrations/test_diffusion.py +++ b/tests/integrations/test_diffusion.py @@ -1,13 +1,14 @@ -"""Tests for diffusion trainer integration.""" +"""Tests for diffusion model integration.""" # pylint: disable=redefined-outer-name,protected-access -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest import torch -from axolotl.integrations.diffusion.trainer import DiffusionTrainer +from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig +from axolotl.integrations.diffusion.models import LlamaForDiffusionLM from axolotl.utils.dict 
import DictDefault @@ -24,37 +25,44 @@ def mock_tokenizer(): @pytest.fixture def diffusion_config(): """Create a diffusion config.""" - return DictDefault( - { - "mask_token_id": 32000, - "eps": 1e-3, - "importance_weighting": False, - "sample_packing": False, - } + return LlamaForDiffusionConfig( + mask_token_id=32000, + eps=1e-3, + importance_weighting=False, + sample_packing=False, + # Basic llama config fields - smaller for testing + vocab_size=1000, + hidden_size=256, + intermediate_size=512, + num_hidden_layers=2, + num_attention_heads=4, ) @pytest.fixture -def diffusion_trainer_instance(mock_tokenizer, diffusion_config): - """Create a diffusion trainer instance for testing methods directly.""" - # Create a minimal trainer instance just for testing methods - trainer = object.__new__(DiffusionTrainer) # Bypass __init__ - trainer.config = diffusion_config - trainer._special_token_ids = {0, 1, 2} # pad, bos, eos - trainer.processing_class = mock_tokenizer - trainer.store_metrics = Mock() # Mock metrics storage - return trainer +def diffusion_model_instance(mock_tokenizer, diffusion_config): + """Create a diffusion model instance for testing methods directly.""" + # Create a minimal model instance for testing + model = object.__new__(LlamaForDiffusionLM) + model.config = diffusion_config + model._special_token_ids = {0, 1, 2} # pad, bos, eos + model.training = True + + # Set tokenizer + model.set_tokenizer(mock_tokenizer) + + return model -class TestDiffusionTrainer: - """Test the DiffusionTrainer class.""" +class TestDiffusionModel: + """Test the DiffusionModel class.""" - def test_forward_process_basic(self, diffusion_trainer_instance): + def test_forward_process_basic(self, diffusion_model_instance): """Test basic forward process without labels.""" input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) noisy_batch, masked_indices, p_mask = ( - diffusion_trainer_instance._forward_process(input_ids, eps=0.1) + diffusion_model_instance._forward_process(input_ids, eps=0.1) ) # Check shapes @@ -67,18 +75,18 @@ class TestDiffusionTrainer: assert not masked_indices[special_token_positions].any() # Check that mask token is applied - mask_token_id = diffusion_trainer_instance._config.mask_token_id + mask_token_id = diffusion_model_instance.config.mask_token_id masked_positions = masked_indices if masked_positions.any(): assert (noisy_batch[masked_positions] == mask_token_id).all() - def test_forward_process_with_labels(self, diffusion_trainer_instance): + def test_forward_process_with_labels(self, diffusion_model_instance): """Test forward process with SFT labels.""" input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) noisy_batch, masked_indices, p_mask = ( - diffusion_trainer_instance._forward_process( + diffusion_model_instance._forward_process( input_ids, labels=labels, eps=0.1 ) ) @@ -100,12 +108,12 @@ class TestDiffusionTrainer: # Verify that masked_indices respects the answer mask assert not masked_indices[non_answer_mask].any() - def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): + def test_forward_process_with_attention_mask(self, diffusion_model_instance): """Test forward process with attention mask.""" input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) - _, masked_indices, p_mask = diffusion_trainer_instance._forward_process( + _, masked_indices, p_mask = 
diffusion_model_instance._forward_process( input_ids, attention_mask=attention_mask, eps=0.1 ) @@ -114,11 +122,11 @@ class TestDiffusionTrainer: assert not masked_indices[padding_positions].any() assert (p_mask[padding_positions] == 0).all() - def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): + def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance): """Test bidirectional attention mask without sample packing.""" input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) - mask = diffusion_trainer_instance._create_bidirectional_attention_mask( + mask = diffusion_model_instance._create_bidirectional_attention_mask( input_ids ) @@ -128,15 +136,15 @@ class TestDiffusionTrainer: assert mask.all() def test_bidirectional_attention_mask_with_packing( - self, diffusion_trainer_instance + self, diffusion_model_instance ): """Test bidirectional attention mask with sample packing.""" - diffusion_trainer_instance._config.sample_packing = True + diffusion_model_instance.config.sample_packing = True input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) # Sample IDs: first sample (1), second sample (2) attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) - mask = diffusion_trainer_instance._create_bidirectional_attention_mask( + mask = diffusion_model_instance._create_bidirectional_attention_mask( input_ids, attention_mask ) @@ -148,124 +156,135 @@ class TestDiffusionTrainer: assert not mask[0, 0, 2, 4].item() assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other - def test_compute_loss_basic(self, diffusion_trainer_instance): + def test_compute_loss_basic(self, diffusion_model_instance): """Test basic loss computation.""" - # Mock model that returns logits - mock_model = Mock() - mock_outputs = Mock() + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + # Create mock data for loss computation vocab_size = 1000 seq_len = 5 - mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) - mock_model.return_value = mock_outputs - mock_model.training = True + logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + + # Create a simple masked indices tensor (mask middle tokens) + masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool) + p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float) - input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) - - loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( - mock_model, input_ids + loss = diffusion_model_instance._compute_diffusion_loss( + input_ids=input_ids, + logits=logits, + masked_indices=masked_indices, + p_mask=p_mask, ) # Check that loss is computed assert isinstance(loss, torch.Tensor) assert loss.requires_grad - assert outputs == mock_outputs - # Check that metrics were stored - diffusion_trainer_instance.store_metrics.assert_called_once() - - def test_compute_loss_with_labels(self, diffusion_trainer_instance): + def test_compute_loss_with_labels(self, diffusion_model_instance): """Test loss computation with SFT labels.""" - # Mock model - mock_model = Mock() - mock_outputs = Mock() - vocab_size = 1000 - seq_len = 5 - mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) - mock_model.return_value = mock_outputs - mock_model.training = True - input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + # Create mock data for loss 
computation + vocab_size = 1000 + seq_len = 5 + logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + + # Create masked indices that only covers answer tokens + masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool) + p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float) - loss, _ = diffusion_trainer_instance._compute_diffusion_loss( - mock_model, input_ids, labels=labels + loss = diffusion_model_instance._compute_diffusion_loss( + input_ids=input_ids, + labels=labels, + logits=logits, + masked_indices=masked_indices, + p_mask=p_mask, ) # Check that loss is computed assert isinstance(loss, torch.Tensor) assert loss.requires_grad - # Check that SFT metrics were added - call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] - assert "answer_ratio" in call_args - assert "avg_answer_length" in call_args - - def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): + def test_compute_loss_no_masked_tokens(self, diffusion_model_instance): """Test loss computation when no tokens are masked.""" - # Mock model - mock_model = Mock() - mock_outputs = Mock() + input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) + + # Create mock data for loss computation vocab_size = 1000 seq_len = 3 - mock_outputs.logits = torch.randn(1, seq_len, vocab_size) - mock_model.return_value = mock_outputs - mock_model.training = True + logits = torch.randn(1, seq_len, vocab_size) + + # No tokens masked + masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool) + p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float) - # Only special tokens (which won't be masked) - input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) - - loss, _ = diffusion_trainer_instance._compute_diffusion_loss( - mock_model, input_ids + loss = diffusion_model_instance._compute_diffusion_loss( + input_ids=input_ids, + logits=logits, + masked_indices=masked_indices, + p_mask=p_mask, ) # Loss should be zero when no tokens are masked assert loss.item() == 0.0 assert loss.requires_grad - def test_cache_special_token_ids(self, diffusion_trainer_instance): + def test_cache_special_token_ids(self, diffusion_model_instance): """Test caching of special token IDs.""" # Should cache BOS, EOS, PAD tokens expected_tokens = {0, 1, 2} # pad, bos, eos - assert diffusion_trainer_instance._special_token_ids == expected_tokens + assert diffusion_model_instance._special_token_ids == expected_tokens def test_cache_special_token_ids_no_tokenizer(self): """Test caching when no tokenizer is available.""" - trainer = object.__new__(DiffusionTrainer) # Bypass __init__ - trainer.processing_class = None - trainer._cache_special_token_ids() + # Mock the parent model initialization to avoid loading pretrained weights + with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'): + model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM) + model._cache_special_token_ids(None) + assert model._special_token_ids == set() - assert trainer._special_token_ids == set() + def test_forward_training_mode(self, diffusion_model_instance): + """Test forward pass in training mode.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool) + + # Mock the parent forward method + with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward: + mock_output = Mock() + mock_output.logits = torch.randn(1, 5, 32000) + mock_forward.return_value = mock_output + + 
# Set training mode + diffusion_model_instance.training = True + + result = diffusion_model_instance.forward( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Should call parent forward and compute loss + assert mock_forward.called + assert hasattr(result, 'loss') - def test_main_compute_loss_interface(self, diffusion_trainer_instance): - """Test the main compute_loss interface.""" - # Mock model - mock_model = Mock() - mock_outputs = Mock() - mock_outputs.logits = torch.randn(1, 5, 1000) - mock_model.return_value = mock_outputs - mock_model.training = True - - inputs = { - "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), - "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), - "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), - } - - # Test without return_outputs - loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) - assert isinstance(loss, torch.Tensor) - - # Test with return_outputs - loss, outputs = diffusion_trainer_instance.compute_loss( - mock_model, inputs, return_outputs=True - ) - assert isinstance(loss, torch.Tensor) - assert outputs == mock_outputs - - def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): - """Test that missing input_ids raises ValueError.""" - mock_model = Mock() - inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} - - with pytest.raises(ValueError, match="input_ids is required"): - diffusion_trainer_instance.compute_loss(mock_model, inputs) + def test_forward_inference_mode(self, diffusion_model_instance): + """Test forward pass in inference mode.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + # Mock the parent forward method + with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward: + mock_output = Mock() + mock_forward.return_value = mock_output + + # Set inference mode + diffusion_model_instance.training = False + + result = diffusion_model_instance.forward( + input_ids=input_ids, + return_dict=True + ) + + # Should just call parent forward without diffusion processing + assert mock_forward.called + assert result == mock_output
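+
+    # Sketch of a registration smoke test one might add; it assumes the AutoConfig /
+    # AutoModel registrations performed in the plugin's __init__.py run at import time.
+    def test_automodel_registration(self, diffusion_config):
+        """AutoModel.from_config should resolve to the registered diffusion class."""
+        from transformers import AutoModel
+
+        # Importing the plugin package triggers the AutoConfig/AutoModel registration.
+        import axolotl.integrations.diffusion  # noqa: F401  # pylint: disable=unused-import
+
+        model = AutoModel.from_config(diffusion_config)
+        assert isinstance(model, LlamaForDiffusionLM)
+        assert model.config.model_type == "llama_diffusion"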