axolotl/src/axolotl/integrations/diffusion
Dan Saunders 1b53c49e1a text diffusion training plugin (#3067)
2025-09-10 20:27:00 -04:00

Diffusion LM Training Plugin for Axolotl

This plugin enables diffusion language model training using an approach inspired by LLaDA (Large Language Diffusion Models) within Axolotl.

Overview

LLaDA is a diffusion-based approach to language model training that uses:

  • Random token masking during training instead of next-token prediction
  • Bidirectional attention to allow the model to attend to the full context
  • Importance weighting based on masking probabilities for stable training

Because each masked position is predicted from context on both sides, this approach may yield more robust language models with a better grasp of bidirectional context.

Installation

The plugin is included with Axolotl. See our installation docs.

Quickstart

Train with one of the example configs (Llama 3.2 1B):

  • Pretrain: axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml
  • SFT: axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml

Basic Configuration

You can also modify your existing configs to enable / customize diffusion training.

Add the following to your Axolotl config:

# Enable diffusion LM training plugin
plugins:
  - axolotl.integrations.diffusion.DiffusionPlugin

Then configure the nested diffusion block (defaults shown):

diffusion:
  noise_schedule: linear  # or "cosine"
  min_mask_ratio: 0.1
  max_mask_ratio: 0.9
  num_diffusion_steps: 128
  eps: 1e-3
  importance_weighting: true

  # Mask token (training auto-adds if missing, avoid pad/eos)
  mask_token_str: "<|diffusion_mask|>"
  # Or use an existing special token id (e.g., 128002 for Llama-3.x)
  # mask_token_id: 128002

  # Sample generation during training (optional)
  generate_samples: true
  generation_interval: 100
  num_generation_samples: 3
  generation_steps: 128
  generation_temperature: 0.0
  generation_max_length: 100

Supported Models

Any model that supports 4D attention masks should work out of the box. If yours does not, please create an issue or open a PR!
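
To illustrate what a 4D bidirectional attention mask looks like, here is a minimal pure-Python sketch. It assumes a boolean mask of shape (batch, 1, seq, seq) in which True means "may attend"; the helper name bidirectional_mask_4d is hypothetical and is not part of the plugin, and real models often expect an additive float mask instead.

```python
def bidirectional_mask_4d(padding_mask):
    """Build a (batch, 1, seq, seq) boolean mask from a (batch, seq) padding mask.

    padding_mask[b][j] is True for real tokens and False for padding.
    Hypothetical helper for illustration only.
    """
    return [
        # One singleton "head" dimension, then a seq x seq grid per example.
        [[[query and key for key in row] for query in row]]
        for row in padding_mask
    ]


mask = bidirectional_mask_4d([[True, True, False]])
# Position 0 may attend to the later position 1, which a causal mask would forbid.
can_look_ahead = mask[0][0][0][1]
```

Unlike a causal mask, nothing here depends on the relative order of query and key positions; only padding is excluded.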

How It Works

Random Masking

During training, tokens are randomly masked:

  • Sample timestep t uniformly from [0, 1]
  • Calculate masking probability: p = (1 - eps) * t + eps
  • Randomly mask tokens with probability p
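
The three steps above can be sketched in plain Python. This is a simplified illustration, not the plugin's implementation: the function name forward_mask and the example token ids are hypothetical, and it ignores details such as padding, special tokens, and the min_mask_ratio / max_mask_ratio settings.

```python
import random


def forward_mask(token_ids, mask_token_id, eps=1e-3, rng=None):
    """Mask each token independently with probability p = (1 - eps) * t + eps."""
    rng = rng or random.Random()
    t = rng.random()                # timestep t sampled uniformly from [0, 1)
    p_mask = (1.0 - eps) * t + eps  # masking probability, never below eps
    is_masked = [rng.random() < p_mask for _ in token_ids]
    noisy = [
        mask_token_id if masked else tok
        for tok, masked in zip(token_ids, is_masked)
    ]
    return noisy, is_masked, p_mask


# Hypothetical token ids; 128002 is the example mask token id from the config above.
tokens = [791, 4062, 14198, 39935, 35308, 927, 279, 16053, 5679]
noisy, is_masked, p_mask = forward_mask(
    tokens, mask_token_id=128002, rng=random.Random(0)
)
```

The eps term keeps p strictly positive, so dividing by the masking probability in the loss is always well defined.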

Diffusion Loss

Loss is computed only on masked tokens with (optional) importance weighting:

loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
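
As a rough sketch of the weighted formula above for a single sequence, assuming total_tokens is the full sequence length (masked and unmasked) and that the model exposes per-position log-probabilities; the function name diffusion_loss is hypothetical:

```python
import math


def diffusion_loss(log_probs, targets, is_masked, p_mask):
    """Importance-weighted cross-entropy over masked positions only.

    log_probs[i][v] is the log-probability of vocabulary id v at position i.
    """
    total = 0.0
    for position_log_probs, target, masked in zip(log_probs, targets, is_masked):
        if masked:  # unmasked positions contribute nothing to the loss
            cross_entropy = -position_log_probs[target]
            total += cross_entropy / p_mask
    # Assumption: normalize by all tokens in the sequence, not just masked ones.
    return total / len(targets)


# Toy example: 3-token vocabulary, uniform predictions, 2 of 4 positions masked.
uniform = [math.log(1 / 3)] * 3
loss = diffusion_loss(
    log_probs=[uniform] * 4,
    targets=[0, 1, 2, 0],
    is_masked=[True, False, True, False],
    p_mask=0.5,
)
```

In this toy case half the tokens are masked with p_mask = 0.5, so the 1 / p_mask weight exactly offsets the unmasked positions and the loss equals the per-token cross-entropy, ln 3.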

Sample Generation

When diffusion.generate_samples: true, the plugin generates samples during training:

Sample 1:
   Original (45 tokens): The quick brown fox jumps over the lazy dog...
   Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
   Generated: The quick brown fox jumps over the lazy dog...

Samples are logged to console and wandb (if enabled).

Inference

Diffusion inference is integrated into the standard Axolotl CLI. Use the same config you trained with and run:

axolotl inference path/to/your-config.yaml

Optionally, pass --gradio to use a simple web interface.

Interactive controls (prefix the prompt with commands):

  • :complete N → completion mode with N new masked tokens appended (default 64)
  • :mask R → random masking mode with target mask ratio R in [0.0, 1.0]

Example session:

================================================================================
Commands:
:complete N -> completion mode with N tokens (default 64)
:mask R     -> random masking with ratio R (0.0-1.0)
================================================================================
Give me an instruction (Ctrl + D to submit):

:mask 0.4 The quick brown fox jumps over the lazy dog

Masked (40.0%):
The [MASK] brown [MASK] jumps over the [MASK] dog

Generated:
The quick brown fox jumps over the loud dog
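
For illustration, a command prefix like the ones in this session could be split from the prompt as follows. This is a hypothetical parser, not the CLI's actual code: the name parse_prompt is invented, and treating a prompt with no command as completion mode with the default length is an assumption.

```python
def parse_prompt(text, default_tokens=64):
    """Return (mode, value, prompt) for input that may start with a command."""
    text = text.strip()
    head, _, rest = text.partition(" ")
    if head == ":mask":
        ratio_str, _, prompt = rest.strip().partition(" ")
        ratio = float(ratio_str)  # raises ValueError if R is missing or malformed
        if not 0.0 <= ratio <= 1.0:
            raise ValueError("mask ratio must be in [0.0, 1.0]")
        return "mask", ratio, prompt.strip()
    if head == ":complete":
        count_str, _, prompt = rest.strip().partition(" ")
        if count_str.isdigit():
            return "complete", int(count_str), prompt.strip()
        # No N given: fall back to the default number of new tokens.
        return "complete", default_tokens, rest.strip()
    # Assumption: input without a command is treated as a completion request.
    return "complete", default_tokens, text


masked_request = parse_prompt(":mask 0.4 The quick brown fox jumps over the lazy dog")
completion_request = parse_prompt(":complete 32 Once upon a time")
```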

Metrics and Monitoring

The plugin adds (or modifies) several metrics to track diffusion training:

  • train/loss: Weighted diffusion loss
  • train/accuracy: Accuracy on masked tokens
  • train/mask_ratio: Average fraction of tokens masked
  • train/num_masked_tokens: Number of tokens masked
  • train/avg_p_mask: Average masking probability
  • train/ce_loss: Unweighted cross-entropy loss
  • train/importance_weight_avg: Average importance weight
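
To make the masking-related metrics concrete, here is a minimal sketch of how mask ratio, masked-token count, and masked-token accuracy could be computed for one sequence. The function name masking_metrics is hypothetical, and the plugin's actual aggregation across batches and devices may differ.

```python
def masking_metrics(predictions, targets, is_masked):
    """Compute per-sequence masking statistics from token-level predictions."""
    num_masked = sum(is_masked)
    num_correct = sum(
        1
        for pred, tgt, masked in zip(predictions, targets, is_masked)
        if masked and pred == tgt
    )
    return {
        "mask_ratio": num_masked / len(targets),
        "num_masked_tokens": num_masked,
        # Accuracy is measured on masked positions only.
        "accuracy": num_correct / num_masked if num_masked else 0.0,
    }


# Toy sequence of 45 tokens with the first 18 masked, as in the sample output above.
targets = list(range(45))
is_masked = [i < 18 for i in range(45)]
# The first 9 predictions are correct; -1 never matches a target.
predictions = [tgt if i < 9 else -1 for i, tgt in enumerate(targets)]
metrics = masking_metrics(predictions, targets, is_masked)
```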

Limitations

  • No flash attention support
  • No RL training support

References