This commit is contained in:
Dan Saunders
2025-08-18 19:17:24 +00:00
parent b210db2d15
commit 63d2280999

View File

@@ -53,8 +53,9 @@ model_type: llama
# Standard Axolotl configuration
datasets:
- path: your_dataset
type: completion # or conversation
...
# Other config
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
@@ -85,31 +86,16 @@ The plugin uses native 4D attention masks to:
Loss is computed only on masked tokens with (optional) importance weighting:
```
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
## Performance Tips
### Memory Optimization
- Bidirectional attention uses more memory than causal attention
- Consider reducing `micro_batch_size` if you encounter OOM errors
- Consider using gradient checkpointing, torch.compile,
### Training Stability
- Start with `noise_schedule: linear` for more predictable behavior
- Enable `importance_weighting: true` for better gradient scaling
### Convergence
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
- Expect different loss curves compared to standard language modeling
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
```
📝 Sample 1:
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
@@ -136,4 +122,4 @@ The plugin adds several metrics to track diffusion training:
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://github.com/OpenAccess-AI-Collective/axolotl)
- [Axolotl Documentation](https://docs.axolotl.ai/)