From 3156c605d4663bdb9580f470e719eba1b6b9f6ae Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 14 Aug 2025 01:48:22 -0400 Subject: [PATCH] diffusion training plugin --- .../llama-3/diffusion-3.2-1b-pretrain.yaml | 60 +++++ examples/llama-3/diffusion-3.2-1b-sft.yaml | 55 ++++ src/axolotl/core/builders/causal.py | 8 + src/axolotl/integrations/diffusion/README.md | 117 +++++++++ .../integrations/diffusion/__init__.py | 10 + src/axolotl/integrations/diffusion/args.py | 43 +++ src/axolotl/integrations/diffusion/plugin.py | 40 +++ src/axolotl/integrations/diffusion/trainer.py | 245 ++++++++++++++++++ 8 files changed, 578 insertions(+) create mode 100644 examples/llama-3/diffusion-3.2-1b-pretrain.yaml create mode 100644 examples/llama-3/diffusion-3.2-1b-sft.yaml create mode 100644 src/axolotl/integrations/diffusion/README.md create mode 100644 src/axolotl/integrations/diffusion/__init__.py create mode 100644 src/axolotl/integrations/diffusion/args.py create mode 100644 src/axolotl/integrations/diffusion/plugin.py create mode 100644 src/axolotl/integrations/diffusion/trainer.py diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml new file mode 100644 index 000000000..7084216bb --- /dev/null +++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml @@ -0,0 +1,60 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Dataset configuration for pretraining +datasets: + - path: wikitext + name: wikitext-103-raw-v1 + type: completion + field: text +val_set_size: 0.001 + +plugins: + - diffusion.DiffusionPlugin +noise_schedule: "cosine" +min_mask_ratio: 0.15 +max_mask_ratio: 0.85 +num_diffusion_steps: 2000 +eps: 5e-4 +importance_weighting: true + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true +eval_sample_packing: true + +gradient_accumulation_steps: 8 +micro_batch_size: 4 +max_steps: 10000 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 3e-4 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +sdp_attention: true + +warmup_steps: 500 + +save_strategy: steps +eval_strategy: steps +save_steps: 1000 +eval_steps: 1000 + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/diffusion-3.2-1b-sft.yaml b/examples/llama-3/diffusion-3.2-1b-sft.yaml new file mode 100644 index 000000000..30c2504b4 --- /dev/null +++ b/examples/llama-3/diffusion-3.2-1b-sft.yaml @@ -0,0 +1,55 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +val_set_size: 0.05 + +plugins: + - diffusion.DiffusionPlugin +noise_schedule: "linear" +min_mask_ratio: 0.1 +max_mask_ratio: 0.9 +num_diffusion_steps: 1000 +eps: 1e-3 +importance_weighting: true + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true +eval_sample_packing: true + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 1e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +sdp_attention: true + +save_strategy: steps +eval_strategy: steps +save_steps: 
500
+eval_steps: 500
+
+special_tokens:
+  pad_token: "<|end_of_text|>"
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index e5bc21762..07bea1237 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -385,10 +385,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             **data_collator_kwargs,
         )
         sig = inspect.signature(trainer_cls)
+
+        # Check whether the trainer class inherits from transformers.Trainer; if so,
+        # pass the tokenizer/processing_class even when it is not in the direct signature.
+        from transformers import Trainer as HFTrainer
+
         if "processing_class" in sig.parameters:
             trainer_kwargs["processing_class"] = self.tokenizer
         elif "tokenizer" in sig.parameters:
             trainer_kwargs["tokenizer"] = self.tokenizer
+        elif issubclass(trainer_cls, HFTrainer):
+            # Subclasses of transformers.Trainer accept processing_class on newer HF versions
+            trainer_kwargs["processing_class"] = self.tokenizer
         if (
             trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
             and self.cfg.datasets is not None
diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md
new file mode 100644
index 000000000..ce2b0c8f7
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/README.md
@@ -0,0 +1,117 @@
+# Diffusion LM Training Plugin for Axolotl
+
+This plugin enables diffusion language model training using the LLaDA (Large Language
+Diffusion with mAsking) approach within the Axolotl framework.
+
+## Overview
+
+LLaDA is a diffusion-based approach to language model training that uses:
+- **Random token masking** during training instead of next-token prediction
+- **Bidirectional attention** to allow the model to see the full context
+- **Importance weighting** based on masking probabilities for stable training
+
+This approach can lead to more robust language models with a better understanding of
+bidirectional context.
+
+## Installation
+
+The plugin is included with Axolotl. To use it, simply add the plugin configuration to
+your training config.
+
+## Quickstart
+
+### Basic Configuration
+
+Add the following to your Axolotl configuration YAML:
+
+```yaml
+# Enable diffusion LM training plugin
+plugins:
+  - diffusion.DiffusionPlugin
+
+# Diffusion-specific configuration
+noise_schedule: "linear"  # or "cosine"
+min_mask_ratio: 0.1
+max_mask_ratio: 0.9
+num_diffusion_steps: 1000
+eps: 1e-3
+importance_weighting: true
+
+# Model configuration
+base_model: meta-llama/Llama-3.2-1B
+model_type: llama
+
+# Standard Axolotl configuration
+datasets:
+  - path: your_dataset
+    type: completion  # or conversation
+
+sequence_len: 1024
+micro_batch_size: 8
+gradient_accumulation_steps: 4
+learning_rate: 3e-4
+```
+
+## Supported Models
+
+Any models that support 4D attention masks should work out of the box. If not, please
+create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
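+
+## Worked Example
+
+The next section describes the forward noising process and the loss in detail. The
+snippet below is a minimal, self-contained sketch of that math in plain PyTorch
+(illustrative only; it does not use the plugin's classes, and the mask token id is
+made up for the example):
+
+```python
+import torch
+import torch.nn.functional as F
+
+torch.manual_seed(0)
+
+batch_size, seq_len, vocab_size = 2, 8, 32
+mask_token_id = vocab_size - 1  # illustrative mask token id
+input_ids = torch.randint(0, vocab_size - 1, (batch_size, seq_len))
+
+# Forward process: sample t ~ U[0, 1] per sample, p = (1 - eps) * t + eps,
+# then mask each token independently with probability p
+eps = 1e-3
+t = torch.rand(batch_size)
+p_mask = ((1 - eps) * t + eps)[:, None].expand(-1, seq_len)
+masked_indices = torch.rand(batch_size, seq_len) < p_mask
+noisy_ids = torch.where(
+    masked_indices, torch.full_like(input_ids, mask_token_id), input_ids
+)
+
+# Stand-in for model logits; a real run computes them from noisy_ids with a
+# bidirectional attention mask
+logits = torch.randn(batch_size, seq_len, vocab_size)
+
+# Loss on masked positions only, importance-weighted by 1 / p_mask and
+# normalized by the total number of tokens in the batch
+token_loss = F.cross_entropy(
+    logits[masked_indices], input_ids[masked_indices], reduction="none"
+)
+loss = (token_loss / p_mask[masked_indices]).sum() / (batch_size * seq_len)
+print(loss)
+```
+
+In an actual training step, `DiffusionTrainer` performs the same computation with the
+model's logits and additionally skips padding positions.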
+
+## How It Works
+
+### Random Masking
+During training, tokens are randomly masked based on a sampled timestep:
+- Sample timestep `t` uniformly from [0, 1]
+- Calculate masking probability: `p = (1 - eps) * t + eps`
+- Randomly mask tokens with probability `p`
+
+### Bidirectional Attention
+The plugin uses native 4D attention masks to:
+- Enable bidirectional attention without patches
+- Allow all tokens to attend to all other tokens
+- Maintain proper padding masks
+- Work with modern `transformers` models out of the box
+
+### Diffusion Loss
+
+Loss is computed only on masked tokens with (optional) importance weighting:
+
+```
+loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
+```
+
+## Performance Tips
+
+### Memory Optimization
+- Bidirectional attention uses more memory than causal attention
+- Consider reducing `micro_batch_size` if you encounter OOM errors
+- Consider using gradient checkpointing and/or `torch.compile`
+
+### Training Stability
+- Start with `noise_schedule: "linear"` for more predictable behavior
+- Enable `importance_weighting` for better gradient scaling
+
+### Convergence
+- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
+- Expect different loss curves compared to standard language modeling
+
+## Metrics and Monitoring
+
+The plugin adds several metrics to track diffusion training:
+
+- `train/diffusion_loss`: Weighted diffusion loss
+- `train/diffusion_accuracy`: Accuracy on masked tokens
+- `train/diffusion_mask_ratio`: Average fraction of tokens masked
+- `train/diffusion_num_masked_tokens`: Number of masked tokens
+- `train/diffusion_avg_p_mask`: Average masking probability
+- `train/diffusion_ce_loss`: Unweighted cross-entropy loss
+- `train/diffusion_importance_weight_avg`: Average importance weight
+
+## Limitations
+
+- No flash attention support (arbitrary 4D attention masks require SDPA or eager attention)
+
+## References
+
+- [LLaDA Paper](https://arxiv.org/abs/2502.09992)
+- [Axolotl Documentation](https://github.com/axolotl-ai-cloud/axolotl)
diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py
new file mode 100644
index 000000000..84ae75ccb
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/__init__.py
@@ -0,0 +1,10 @@
+"""
+Diffusion LM training plugin for Axolotl.
+
+This plugin enables diffusion language model training using the LLaDA approach.
+""" + +from .args import DiffusionArgs +from .plugin import DiffusionPlugin + +__all__ = ["DiffusionArgs", "DiffusionPlugin"] diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py new file mode 100644 index 000000000..949b08824 --- /dev/null +++ b/src/axolotl/integrations/diffusion/args.py @@ -0,0 +1,43 @@ +"""Configuration arguments for diffusion LM training.""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class DiffusionArgs(BaseModel): + """Arguments for diffusion LM training plugin.""" + + # Noise schedule configuration + noise_schedule: Literal["linear", "cosine"] = Field( + default="linear", description="Type of noise schedule for diffusion training" + ) + min_mask_ratio: float = Field( + default=0.1, + ge=0.0, + le=1.0, + description="Minimum masking ratio for diffusion noise schedule", + ) + max_mask_ratio: float = Field( + default=0.9, + ge=0.0, + le=1.0, + description="Maximum masking ratio for diffusion noise schedule", + ) + num_diffusion_steps: int = Field( + default=1000, ge=1, description="Number of diffusion timesteps" + ) + + # Forward process parameters + eps: float = Field( + default=1e-3, + ge=0.0, + le=1.0, + description="Epsilon value for minimum masking probability in forward process", + ) + + # Training configuration + importance_weighting: bool = Field( + default=True, + description="Apply importance weighting to loss based on masking probability", + ) diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py new file mode 100644 index 000000000..76eea006d --- /dev/null +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -0,0 +1,40 @@ +"""Diffusion LM training plugin for Axolotl.""" + +from transformers import PreTrainedModel, Trainer + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .trainer import DiffusionTrainer + +LOG = get_logger(__name__) + + +class DiffusionPlugin(BasePlugin): + """ + Plugin for diffusion language model training. + + This plugin enables diffusion-based training using the LLaDA approach, which uses + random masking and bidirectional attention to train language models. 
+ """ + + def __init__(self): + super().__init__() + self.cfg = None + + def get_input_args(self) -> str: + """Returns the pydantic model for LLaDA plugin arguments.""" + return "axolotl.integrations.diffusion.DiffusionArgs" + + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel): + """Perform actions after model is loaded.""" + self.cfg = cfg + + def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + """Return custom trainer class for diffusion training.""" + return DiffusionTrainer + + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Configure trainer after creation.""" + trainer.set_config(cfg) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py new file mode 100644 index 000000000..bb178341f --- /dev/null +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -0,0 +1,245 @@ +"""Custom trainer for diffusion LM training.""" + +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import PreTrainedModel + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors + """Custom trainer for diffusion LM training that overrides loss computation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config = None + + def set_config(self, config: DictDefault): + """Set config for diffusion training.""" + self.config = config + + def forward_process( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + eps: float = 1e-3, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward noising process. A timestep is sampled along the process, and tokens are + masked with probability determined by the configured noise schedule. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + eps: Small epsilon value for minimum masking probability. + + Returns: + noisy_batch: Input with some tokens masked. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Sample random timesteps for each sample in batch + t = torch.rand(batch_size, device=device) + + # Calculate masking probability with epsilon + p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] + + # Don't mask padding tokens if attention_mask is provided + if attention_mask is not None: + valid_mask = attention_mask.bool() + p_mask = p_mask * valid_mask.float() + + # Create random mask based on probability + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + if attention_mask is not None: + masked_indices = masked_indices & attention_mask.bool() + + # Get tokenizer + tokenizer = self.processing_class + assert tokenizer is not None, "Tokenizer not available on Trainer object." 
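+        # Masked positions are replaced with the tokenizer's mask token below, falling
+        # back to the unk token. If the tokenizer defines neither (mask_token_id stays
+        # None), the torch.where call below will fail, so a dedicated mask token should
+        # be added to the tokenizer in that case.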
+ + # Get mask token ID + mask_token_id = getattr(tokenizer, "mask_token_id", None) + if mask_token_id is None: + mask_token_id = getattr(tokenizer, "unk_token_id", None) + + # Create masked input using configured mask token + noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) + + return noisy_batch, masked_indices, p_mask + + def create_bidirectional_attention_mask( + self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Create bidirectional attention mask to override default causal masking. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + + Returns: + bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]. + """ + batch_size, seq_len = input_ids.shape + + # Create bidirectional attention mask to override default causal masking + # Shape: [batch_size, 1, seq_len, seq_len] + bidirectional_mask = torch.ones( + seq_len, seq_len, dtype=torch.bool, device=input_ids.device + ) + bidirectional_mask = ( + bidirectional_mask.unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, 1, seq_len, seq_len) + ) + + # Apply padding mask if provided + if attention_mask is not None: + # Convert attention_mask to 4D and apply + expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2) + expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len) + + bidirectional_mask = ( + bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2) + ) + + return bidirectional_mask + + def compute_diffusion_loss( + self, + model: PreTrainedModel, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, float]]: + """ + Compute diffusion loss. + + Args: + model: The model to compute loss for. + input_ids: Ground truth token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + + Returns: + loss: Cross-entropy loss. + metrics: Dictionary of metrics. 
+ """ + # Apply forward process + noisy_batch, masked_indices, p_mask = self.forward_process( + input_ids, attention_mask, self.config.eps + ) + + # Create bidirectional attention mask (always required for diffusion training) + bidirectional_mask = self.create_bidirectional_attention_mask( + input_ids, attention_mask + ) + + # Forward pass + outputs = model( + input_ids=noisy_batch, + attention_mask=bidirectional_mask, + ) + logits = outputs.logits + + # Apply attention mask to masked_indices if provided + if attention_mask is not None: + loss_mask = masked_indices & attention_mask.bool() + else: + loss_mask = masked_indices + + if loss_mask.sum() > 0: + valid_indices = torch.where(loss_mask) + batch_indices, seq_indices = valid_indices + + # Extract the relevant data + masked_logits = logits[ + batch_indices, seq_indices + ] # [num_masked_tokens, vocab_size] + masked_targets = input_ids[ + batch_indices, seq_indices + ] # [num_masked_tokens] + masked_p_mask = p_mask[batch_indices, seq_indices] # [num_masked_tokens] + + # Compute cross-entropy loss without reduction (cast to fp32 for stability) + token_loss = F.cross_entropy( + masked_logits.float(), masked_targets, reduction="none" + ) + + # Apply importance weighting if enabled + if self.config.importance_weighting: + masked_p_mask = masked_p_mask.float() + weighted_loss = token_loss / masked_p_mask + else: + weighted_loss = token_loss + + # Final loss: sum weighted losses, normalize by total tokens + loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) + ce_loss = token_loss.mean() + + # Compute accuracy on masked tokens + with torch.no_grad(): + pred_tokens = masked_logits.argmax(dim=-1) + accuracy = (pred_tokens == masked_targets).float().mean() + else: + loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + accuracy = torch.tensor(0.0, device=input_ids.device) + ce_loss = torch.tensor(0.0, device=input_ids.device) + masked_p_mask = torch.tensor(1.0, device=input_ids.device) + + metrics = { + "loss": loss.item(), + "accuracy": accuracy.item(), + "mask_ratio": masked_indices.float().mean().item(), + "num_masked_tokens": loss_mask.sum().item(), + "avg_p_mask": ( + p_mask[masked_indices].mean().item() + if masked_indices.sum() > 0 + else 0.0 + ), + "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0, + } + + if self.config.importance_weighting: + metrics["importance_weight_avg"] = ( + (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0 + ) + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel, + inputs: Dict[str, torch.Tensor], + return_outputs: bool = False, + num_items_in_batch: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """Override compute_loss to use diffusion loss.""" + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + + if input_ids is None: + raise ValueError("input_ids is required for diffusion training") + + loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask) + + # Log metrics + if self.state.is_local_process_zero: + for key, value in metrics.items(): + self.log({f"train/diffusion_{key}": value}) + + if return_outputs: + # TODO: compute outputs (?) + outputs = [loss] + return (loss, outputs) + + return loss
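
For reference, a minimal standalone sketch of the bidirectional 4D attention mask that
`create_bidirectional_attention_mask` builds from a padding mask (plain PyTorch,
illustrative only; the tensor values are made up):

```python
import torch

batch_size, seq_len = 2, 6
# Example padding mask: the second sequence has two padding positions at the end
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0, 0]])

# Start fully bidirectional: every position may attend to every position
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)
bidirectional = bidirectional[None, None].expand(batch_size, 1, seq_len, seq_len)

# Remove attention to and from padding positions
padding = attention_mask.bool()[:, None, None, :]  # [batch, 1, 1, seq_len]
padding = padding.expand(batch_size, 1, seq_len, seq_len)
mask_4d = bidirectional & padding & padding.transpose(-1, -2)

print(mask_4d.shape)        # torch.Size([2, 1, 6, 6])
print(mask_4d[1, 0].int())  # last two rows and columns are zero for the padded sample
```

Passing this 4D mask as `attention_mask` overrides the default causal mask; it works
with SDPA or eager attention, which is why the plugin does not support flash attention.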