diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md
index b4176dd60..f79d5a46b 100644
--- a/src/axolotl/integrations/diffusion/README.md
+++ b/src/axolotl/integrations/diffusion/README.md
@@ -53,8 +53,9 @@ model_type: llama
 
 # Standard Axolotl configuration
 datasets:
   - path: your_dataset
-    type: completion # or conversation
+    ...
+# Other config
 sequence_len: 1024
 micro_batch_size: 8
 gradient_accumulation_steps: 4
@@ -85,31 +86,16 @@ The plugin uses native 4D attention masks to:
 
 Loss is computed only on masked tokens with (optional) importance weighting:
 
-```
+```python
 loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
 ```
 
-## Performance Tips
-
-### Memory Optimization
-- Bidirectional attention uses more memory than causal attention
-- Consider reducing `micro_batch_size` if you encounter OOM errors
-- Consider using gradient checkpointing, torch.compile,
-
-### Training Stability
-- Start with `noise_schedule: linear` for more predictable behavior
-- Enable `importance_weighting: true` for better gradient scaling
-
-### Convergence
-- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
-- Expect different loss curves compared to standard language modeling
-
 ## Sample Generation
 
 When `generate_samples: true`, the plugin generates samples during training:
 
 ```
-📝 Sample 1:
+Sample 1:
 Original (45 tokens): The quick brown fox jumps over the lazy dog...
 Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
 Generated: The quick brown fox jumps over the lazy dog...
@@ -136,4 +122,4 @@ The plugin adds several metrics to track diffusion training:
 ## References
 
 - [LLaDA Paper](https://arxiv.org/abs/2404.10406)
-- [Axolotl Documentation](https://github.com/OpenAccess-AI-Collective/axolotl)
+- [Axolotl Documentation](https://docs.axolotl.ai/)
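
A note on the loss formula retained in the second hunk (not part of the patch itself): the README's pseudocode maps directly onto a per-token cross-entropy. Below is a minimal PyTorch sketch of that computation; the function name, tensor layout, and the `mask`/`p_mask` arguments are illustrative assumptions, not the plugin's actual API.

```python
# Minimal sketch of the importance-weighted masked-token loss from the README:
#   loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
# All names here are illustrative assumptions, not the plugin's real API.
import torch
import torch.nn.functional as F

def diffusion_loss(
    logits: torch.Tensor,  # (batch, seq, vocab) model predictions
    labels: torch.Tensor,  # (batch, seq) original token ids
    mask: torch.Tensor,    # (batch, seq) bool, True at masked positions
    p_mask: torch.Tensor,  # (batch, seq) probability each token was masked
) -> torch.Tensor:
    # Per-token cross-entropy; reduction="none" keeps the (batch, seq)
    # shape so tokens can be individually weighted and selected.
    token_ce = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
    # Importance weighting: dividing by p_mask up-weights tokens that were
    # unlikely to be masked, de-biasing the loss estimate; loss is taken
    # only on masked positions.
    weighted = (token_ce / p_mask)[mask]
    # Normalize by the total token count, matching `/ total_tokens` above.
    return weighted.sum() / labels.numel()
```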