From 479a454ae387e236d66c88f0a29c4f50877339e8 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 14 Aug 2025 16:11:37 -0400
Subject: [PATCH] fixes + improvements

---
 .../llama-3/diffusion-3.2-1b-pretrain.yaml    |   8 +-
 src/axolotl/integrations/diffusion/args.py    |   4 +
 src/axolotl/integrations/diffusion/trainer.py | 135 +++++++++---------
 3 files changed, 78 insertions(+), 69 deletions(-)

diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
index 7084216bb..95d820cca 100644
--- a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
+++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
@@ -22,8 +22,8 @@ importance_weighting: true
 output_dir: ./outputs/model-out
 
 sequence_len: 512
-sample_packing: true
-eval_sample_packing: true
+sample_packing: false
+eval_sample_packing: false
 
 gradient_accumulation_steps: 8
 micro_batch_size: 4
@@ -51,8 +51,8 @@ eval_steps: 1000
 special_tokens:
   pad_token: "<|end_of_text|>"
 
-wandb_project:
-wandb_entity:
+wandb_project: diffusion-plugin
+wandb_entity: axolotl-ai
 wandb_watch:
 wandb_name:
 wandb_log_model:
diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py
index 639c85055..f01db087c 100644
--- a/src/axolotl/integrations/diffusion/args.py
+++ b/src/axolotl/integrations/diffusion/args.py
@@ -41,3 +41,7 @@ class DiffusionArgs(BaseModel):
         default=True,
         description="Apply importance weighting to loss based on masking probability",
     )
+    mask_token_id: int = Field(
+        default=128002,
+        description="Token ID to use for masking. Default is 128002 (<|reserved_special_token_0|> for Llama 3.2)",
+    )
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
index bb178341f..160b5692b 100644
--- a/src/axolotl/integrations/diffusion/trainer.py
+++ b/src/axolotl/integrations/diffusion/trainer.py
@@ -1,10 +1,8 @@
 """Custom trainer for diffusion LM training."""
 
-from typing import Dict, Optional, Tuple, Union
-
 import torch
 import torch.nn.functional as F
-from transformers import PreTrainedModel
+from torch import nn
 
 from axolotl.core.trainers.base import AxolotlTrainer
 from axolotl.utils.dict import DictDefault
@@ -19,17 +17,37 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.config = None
+        self._special_token_ids = None
 
     def set_config(self, config: DictDefault):
         """Set config for diffusion training."""
         self.config = config
+        self._cache_special_token_ids()
+
+    def _cache_special_token_ids(self):
+        """Cache special token IDs to avoid repeated tokenizer access."""
+        if self.processing_class is None:
+            self._special_token_ids = set()
+            return
+
+        tokenizer = self.processing_class
+        special_tokens = set()
+
+        if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
+            special_tokens.add(tokenizer.bos_token_id)
+        if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
+            special_tokens.add(tokenizer.eos_token_id)
+        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
+            special_tokens.add(tokenizer.pad_token_id)
+
+        self._special_token_ids = special_tokens
 
     def forward_process(
         self,
         input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor | None = None,
         eps: float = 1e-3,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Forward noising process. A timestep is sampled along the process, and
         tokens are masked with probability determined by the configured noise schedule.
@@ -59,19 +77,20 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             valid_mask = attention_mask.bool()
             p_mask = p_mask * valid_mask.float()
 
-        # Create random mask based on probability
+        # Create mask to exclude special tokens (BOS, EOS, PAD) using cached IDs
+        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+        if self._special_token_ids:
+            for token_id in self._special_token_ids:
+                special_token_mask |= input_ids == token_id
+
+        # Create random mask based on probability, excluding special tokens
         masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
+        masked_indices = masked_indices & ~special_token_mask
         if attention_mask is not None:
            masked_indices = masked_indices & attention_mask.bool()
 
-        # Get tokenizer
-        tokenizer = self.processing_class
-        assert tokenizer is not None, "Tokenizer not available on Trainer object."
-
-        # Get mask token ID
-        mask_token_id = getattr(tokenizer, "mask_token_id", None)
-        if mask_token_id is None:
-            mask_token_id = getattr(tokenizer, "unk_token_id", None)
+        # Get mask token ID from config
+        mask_token_id = self.config.mask_token_id
 
         # Create masked input using configured mask token
         noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
@@ -79,49 +98,47 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
         return noisy_batch, masked_indices, p_mask
 
     def create_bidirectional_attention_mask(
-        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
         """
         Create bidirectional attention mask to override default causal masking.
+        Handles sample-packed sequences where different samples are identified
+        by different attention mask values.
 
         Args:
             input_ids: Input token ids [batch_size, seq_len].
-            attention_mask: Attention mask [batch_size, seq_len].
+            attention_mask: Attention mask [batch_size, seq_len]
 
         Returns:
             bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
         """
         batch_size, seq_len = input_ids.shape
+        device = input_ids.device
 
-        # Create bidirectional attention mask to override default causal masking
-        # Shape: [batch_size, 1, seq_len, seq_len]
-        bidirectional_mask = torch.ones(
-            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
-        )
-        bidirectional_mask = (
-            bidirectional_mask.unsqueeze(0)
-            .unsqueeze(0)
-            .expand(batch_size, 1, seq_len, seq_len)
-        )
-
-        # Apply padding mask if provided
-        if attention_mask is not None:
-            # Convert attention_mask to 4D and apply
-            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
-            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)
-
-            bidirectional_mask = (
-                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
+        if attention_mask is None or not self.config.sample_packing:
+            # Simple case: no attention mask, allow all-to-all attention
+            return torch.ones(
+                batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
             )
 
+        # Create attention mask by comparing sample IDs element-wise
+        mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
+        mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]
+
+        # Tokens can attend to each other if they have the same non-zero sample ID
+        bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
+
+        # Add head dimension: [batch_size, 1, seq_len, seq_len]
+        bidirectional_mask = bidirectional_mask.unsqueeze(1)
+
         return bidirectional_mask
 
     def compute_diffusion_loss(
         self,
-        model: PreTrainedModel,
+        model: nn.Module,
         input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Dict[str, float]]:
+        attention_mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, dict[str, float]]:
         """
         Compute diffusion loss.
 
@@ -139,7 +156,7 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             input_ids, attention_mask, self.config.eps
         )
 
-        # Create bidirectional attention mask (always required for diffusion training)
+        # Create bidirectional attention mask
         bidirectional_mask = self.create_bidirectional_attention_mask(
             input_ids, attention_mask
         )
@@ -151,14 +168,8 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
         )
         logits = outputs.logits
 
-        # Apply attention mask to masked_indices if provided
-        if attention_mask is not None:
-            loss_mask = masked_indices & attention_mask.bool()
-        else:
-            loss_mask = masked_indices
-
-        if loss_mask.sum() > 0:
-            valid_indices = torch.where(loss_mask)
+        if masked_indices.sum() > 0:
+            valid_indices = torch.where(masked_indices)
             batch_indices, seq_indices = valid_indices
 
             # Extract the relevant data
@@ -200,29 +211,23 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             "loss": loss.item(),
             "accuracy": accuracy.item(),
             "mask_ratio": masked_indices.float().mean().item(),
-            "num_masked_tokens": loss_mask.sum().item(),
-            "avg_p_mask": (
-                p_mask[masked_indices].mean().item()
-                if masked_indices.sum() > 0
-                else 0.0
-            ),
-            "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
+            "num_masked_tokens": masked_indices.sum().item(),
+            "avg_p_mask": p_mask[masked_indices].mean().item(),
+            "ce_loss": ce_loss.item(),
         }
 
         if self.config.importance_weighting:
-            metrics["importance_weight_avg"] = (
-                (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
-            )
+            metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
 
         return loss, metrics
 
     def compute_loss(
         self,
-        model: PreTrainedModel,
-        inputs: Dict[str, torch.Tensor],
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor],
         return_outputs: bool = False,
-        num_items_in_batch: Optional[int] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+        num_items_in_batch: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
         """Override compute_loss to use diffusion loss."""
         input_ids = inputs.get("input_ids")
         attention_mask = inputs.get("attention_mask")
@@ -232,10 +237,10 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
 
         loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
 
-        # Log metrics
-        if self.state.is_local_process_zero:
-            for key, value in metrics.items():
-                self.log({f"train/diffusion_{key}": value})
+        # # Log metrics
+        # if self.state.is_local_process_zero:
+        #     for key, value in metrics.items():
+        #         self.log({f"train/diffusion_{key}": value})
 
         if return_outputs:
             # TODO: compute outputs (?)
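
A minimal sketch (not part of the patch) of the packed-sequence attention logic that
create_bidirectional_attention_mask now relies on, assuming the packing collator writes
per-sample IDs (1, 2, ...) into attention_mask and 0 for padding; the toy tensor values
below are illustrative only:

    import torch

    # Toy packed batch: sample 1 occupies positions 0-2, sample 2 occupies
    # positions 3-4, and the final position is padding (ID 0).
    attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])

    # Same element-wise comparison the trainer performs above.
    mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
    mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]
    bidirectional_mask = ((mask_i == mask_j) & (mask_i > 0)).unsqueeze(1)

    print(bidirectional_mask[0, 0].int())
    # Positions 0-2 attend only to each other, positions 3-4 only to each
    # other, and the padding position attends to nothing.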