diffusion training plugin

Dan Saunders
2025-08-14 01:48:22 -04:00
parent 09145de8fa
commit 3156c605d4
8 changed files with 578 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Dataset configuration for pretraining
datasets:
  - path: wikitext
    name: wikitext-103-raw-v1
    type: completion
    field: text
val_set_size: 0.001
plugins:
  - diffusion.DiffusionPlugin
noise_schedule: "cosine"
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 2000
eps: 5e-4
importance_weighting: true
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 500
save_strategy: steps
eval_strategy: steps
save_steps: 1000
eval_steps: 1000
special_tokens:
  pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,55 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
val_set_size: 0.05
plugins:
  - diffusion.DiffusionPlugin
noise_schedule: "linear"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 1000
eps: 1e-3
importance_weighting: true
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500
special_tokens:
  pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -385,10 +385,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                **data_collator_kwargs,
            )

        sig = inspect.signature(trainer_cls)

        # Check if the trainer class inherits from transformers.Trainer. If so, we
        # should pass the tokenizer / processing_class even if it is not in the
        # direct signature.
        from transformers import Trainer as HFTrainer

        if "processing_class" in sig.parameters:
            trainer_kwargs["processing_class"] = self.tokenizer
        elif "tokenizer" in sig.parameters:
            trainer_kwargs["tokenizer"] = self.tokenizer
        elif issubclass(trainer_cls, HFTrainer):
            # For subclasses of transformers.Trainer, try processing_class first
            # (newer HF versions)
            trainer_kwargs["processing_class"] = self.tokenizer

        if (
            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
            and self.cfg.datasets is not None

View File

@@ -0,0 +1,117 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using the LLaDA (Large Language
Diffusion with mAsking) approach within the Axolotl framework.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to see the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. To use it, simply add the plugin configuration to
your training config.
## Quickstart
### Basic Configuration
Add the following to your Axolotl configuration YAML:
```yaml
# Enable diffusion LM training plugin
plugins:
  - diffusion.DiffusionPlugin

# Diffusion-specific configuration
noise_schedule: "linear"  # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 1000
eps: 1e-3
importance_weighting: true

# Model configuration
base_model: meta-llama/Llama-3.2-1B
model_type: llama

# Standard Axolotl configuration
datasets:
  - path: your_dataset
    type: completion  # or conversation
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 3e-4
```
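Assuming the config is saved as `diffusion.yaml` (the filename is arbitrary), training then runs through the standard Axolotl CLI with `axolotl train diffusion.yaml`.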
## Supported Models
Any models that support 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
## How It Works
### Random Masking
During training, tokens are randomly masked based on a sampled timestep (see the sketch after this list):
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
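A minimal PyTorch sketch of this step, matching the linear schedule above (`forward_mask` is an illustrative name, not part of the plugin API):
```python
import torch

def forward_mask(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Mask each token with probability p = (1 - eps) * t + eps, t ~ U[0, 1]."""
    bsz, seq_len = input_ids.shape
    t = torch.rand(bsz, device=input_ids.device)  # one timestep per sample
    p_mask = ((1 - eps) * t + eps)[:, None].expand(bsz, seq_len)
    masked = torch.rand(bsz, seq_len, device=input_ids.device) < p_mask
    noisy = torch.where(masked, mask_token_id, input_ids)  # swap in mask token
    return noisy, masked, p_mask
```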
### Bidirectional Attention
The plugin uses native 4D attention masks (a condensed sketch follows this list) to:
- Enable bidirectional attention without patches
- Allow all tokens to attend to all other tokens
- Maintain proper padding masks
- Work with modern `transformers` models out of the box
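A condensed sketch of the mask construction, assuming a standard `[batch, seq]` padding mask (`bidirectional_4d_mask` is an illustrative name):
```python
import torch

def bidirectional_4d_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """[batch, seq] padding mask -> [batch, 1, seq, seq] mask; True = may attend."""
    bsz, seq_len = attention_mask.shape
    full = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=attention_mask.device
    )
    mask = full[None, None].expand(bsz, 1, seq_len, seq_len)  # all-to-all attention
    pad = attention_mask.bool()[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    return mask & pad & pad.transpose(-1, -2)  # drop padded rows and columns
```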
### Diffusion Loss
Loss is computed only on masked tokens, with optional importance weighting that divides each token's loss by its masking probability:
```
loss = sum(cross_entropy(pred, target) / p_mask) / (batch_size * seq_len)
```
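The same computation as a minimal sketch (`diffusion_loss` is an illustrative name; inputs are assumed to be gathered at the masked positions only):
```python
import torch.nn.functional as F

def diffusion_loss(logits, targets, p_mask, batch_size, seq_len):
    """logits: [num_masked, vocab]; targets, p_mask: [num_masked]."""
    token_loss = F.cross_entropy(logits.float(), targets, reduction="none")
    weighted = token_loss / p_mask  # importance weighting by masking probability
    return weighted.sum() / (batch_size * seq_len)  # normalize over all tokens
```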
## Performance Tips
### Memory Optimization
- Bidirectional attention uses more memory than causal attention
- Consider reducing `micro_batch_size` if you encounter OOM errors
- Consider enabling gradient checkpointing and `torch.compile` to reduce memory usage
### Training Stability
- Start with `noise_schedule: "linear"` for more predictable behavior
- Enable `importance_weighting` for better gradient scaling
### Convergence
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
- Expect different loss curves compared to standard language modeling
## Metrics and Monitoring
The plugin adds several metrics to track diffusion training:
- `train/diffusion_loss`: Weighted diffusion loss
- `train/diffusion_accuracy`: Accuracy on masked tokens
- `train/diffusion_mask_ratio`: Average fraction of tokens masked
- `train/diffusion_num_masked_tokens`: Number of tokens masked
- `train/diffusion_avg_p_mask`: Average masking probability
- `train/diffusion_ce_loss`: Unweighted cross-entropy loss
- `train/diffusion_importance_weight_avg`: Average importance weight
## Limitations
- No flash attention support; the custom 4D attention masks require SDPA (`sdp_attention: true`, as in the example configs) or eager attention
## References
- [LLaDA Paper](https://arxiv.org/abs/2502.09992)
- [Axolotl Documentation](https://github.com/axolotl-ai-cloud/axolotl)

View File

@@ -0,0 +1,10 @@
"""
Diffusion LM training plugin for Axolotl.
This plugin enables diffusion language model training using the LLaDA approach.
"""
from .args import DiffusionArgs
from .plugin import DiffusionPlugin
__all__ = ["DiffusionArgs", "DiffusionPlugin"]

View File

@@ -0,0 +1,43 @@
"""Configuration arguments for diffusion LM training."""
from typing import Literal
from pydantic import BaseModel, Field
class DiffusionArgs(BaseModel):
"""Arguments for diffusion LM training plugin."""
# Noise schedule configuration
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=1000, ge=1, description="Number of diffusion timesteps"
)
# Forward process parameters
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training configuration
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)

View File

@@ -0,0 +1,40 @@
"""Diffusion LM training plugin for Axolotl."""
from transformers import PreTrainedModel, Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .trainer import DiffusionTrainer
LOG = get_logger(__name__)
class DiffusionPlugin(BasePlugin):
"""
Plugin for diffusion language model training.
This plugin enables diffusion-based training using the LLaDA approach, which uses
random masking and bidirectional attention to train language models.
"""
def __init__(self):
super().__init__()
self.cfg = None
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Perform actions after model is loaded."""
self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Return custom trainer class for diffusion training."""
return DiffusionTrainer
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Configure trainer after creation."""
trainer.set_config(cfg)

View File

@@ -0,0 +1,245 @@
"""Custom trainer for diffusion LM training."""
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
    def forward_process(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        eps: float = 1e-3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward noising process. A timestep is sampled along the process, and tokens
        are masked with probability determined by the configured noise schedule.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            eps: Small epsilon value for minimum masking probability.

        Returns:
            noisy_batch: Input with some tokens masked.
            masked_indices: Boolean mask indicating which tokens were masked.
            p_mask: Masking probabilities for each token [batch_size, seq_len].
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Sample random timesteps for each sample in batch
        t = torch.rand(batch_size, device=device)

        # Calculate masking probability with epsilon
        p_mask = (1 - eps) * t + eps  # [batch_size]
        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]

        # Don't mask padding tokens if attention_mask is provided
        if attention_mask is not None:
            valid_mask = attention_mask.bool()
            p_mask = p_mask * valid_mask.float()

        # Create random mask based on probability
        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
        if attention_mask is not None:
            masked_indices = masked_indices & attention_mask.bool()

        # Get tokenizer and mask token ID, falling back to the unk token
        tokenizer = self.processing_class
        assert tokenizer is not None, "Tokenizer not available on Trainer object."
        mask_token_id = getattr(tokenizer, "mask_token_id", None)
        if mask_token_id is None:
            mask_token_id = getattr(tokenizer, "unk_token_id", None)
        if mask_token_id is None:
            raise ValueError(
                "Tokenizer defines neither a mask token nor an unk token; one is "
                "required to mask inputs for diffusion training."
            )

        # Create masked input using configured mask token
        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)

        return noisy_batch, masked_indices, p_mask
    def create_bidirectional_attention_mask(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Create bidirectional attention mask to override default causal masking.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
        """
        batch_size, seq_len = input_ids.shape

        # Create bidirectional attention mask to override default causal masking
        # Shape: [batch_size, 1, seq_len, seq_len]
        bidirectional_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
        )
        bidirectional_mask = (
            bidirectional_mask.unsqueeze(0)
            .unsqueeze(0)
            .expand(batch_size, 1, seq_len, seq_len)
        )

        # Apply padding mask if provided
        if attention_mask is not None:
            # Convert attention_mask to 4D and apply to both query and key positions
            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)
            bidirectional_mask = (
                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
            )

        return bidirectional_mask
    def compute_diffusion_loss(
        self,
        model: PreTrainedModel,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute diffusion loss.

        Args:
            model: The model to compute loss for.
            input_ids: Ground truth token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            loss: Weighted diffusion loss.
            metrics: Dictionary of metrics.
        """
        # Apply forward process
        noisy_batch, masked_indices, p_mask = self.forward_process(
            input_ids, attention_mask, self.config.eps
        )

        # Create bidirectional attention mask (always required for diffusion training)
        bidirectional_mask = self.create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        # Forward pass
        outputs = model(
            input_ids=noisy_batch,
            attention_mask=bidirectional_mask,
        )
        logits = outputs.logits

        # Apply attention mask to masked_indices if provided
        if attention_mask is not None:
            loss_mask = masked_indices & attention_mask.bool()
        else:
            loss_mask = masked_indices

        if loss_mask.sum() > 0:
            valid_indices = torch.where(loss_mask)
            batch_indices, seq_indices = valid_indices

            # Extract the relevant data
            masked_logits = logits[
                batch_indices, seq_indices
            ]  # [num_masked_tokens, vocab_size]
            masked_targets = input_ids[
                batch_indices, seq_indices
            ]  # [num_masked_tokens]
            masked_p_mask = p_mask[batch_indices, seq_indices]  # [num_masked_tokens]

            # Compute cross-entropy loss without reduction (cast to fp32 for stability)
            token_loss = F.cross_entropy(
                masked_logits.float(), masked_targets, reduction="none"
            )

            # Apply importance weighting if enabled
            if self.config.importance_weighting:
                masked_p_mask = masked_p_mask.float()
                weighted_loss = token_loss / masked_p_mask
            else:
                weighted_loss = token_loss

            # Final loss: sum weighted losses, normalize by total tokens
            loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
            ce_loss = token_loss.mean()

            # Compute accuracy on masked tokens
            with torch.no_grad():
                pred_tokens = masked_logits.argmax(dim=-1)
                accuracy = (pred_tokens == masked_targets).float().mean()
        else:
            loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
            accuracy = torch.tensor(0.0, device=input_ids.device)
            ce_loss = torch.tensor(0.0, device=input_ids.device)
            masked_p_mask = torch.tensor(1.0, device=input_ids.device)

        metrics = {
            "loss": loss.item(),
            "accuracy": accuracy.item(),
            "mask_ratio": masked_indices.float().mean().item(),
            "num_masked_tokens": loss_mask.sum().item(),
            "avg_p_mask": (
                p_mask[masked_indices].mean().item()
                if masked_indices.sum() > 0
                else 0.0
            ),
            "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
        }
        if self.config.importance_weighting:
            metrics["importance_weight_avg"] = (
                (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
            )

        return loss, metrics
    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Override compute_loss to use diffusion loss."""
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        if input_ids is None:
            raise ValueError("input_ids is required for diffusion training")

        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)

        # Log metrics
        if self.state.is_local_process_zero:
            for key, value in metrics.items():
                self.log({f"train/diffusion_{key}": value})

        if return_outputs:
            # TODO: compute outputs (?)
            outputs = [loss]
            return (loss, outputs)
        return loss