nits
@@ -53,8 +53,9 @@ model_type: llama
```yaml
# Standard Axolotl configuration
datasets:
  - path: your_dataset
    type: completion # or conversation
...

# Other config
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
```
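
With the values above, the effective per-device batch size is `micro_batch_size × gradient_accumulation_steps` = 8 × 4 = 32.
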
@@ -85,31 +86,16 @@ The plugin uses native 4D attention masks to:

Loss is computed only on masked tokens with (optional) importance weighting:

```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
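
To make the formula concrete, here is a minimal PyTorch sketch of how a loss of this shape can be computed. The function name, tensor shapes, and argument names are illustrative assumptions, not the plugin's actual implementation:

```python
import torch.nn.functional as F

def masked_diffusion_loss(logits, targets, mask, p_mask):
    """Cross-entropy on masked positions only, up-weighted by 1 / p_mask.

    logits:  (batch, seq, vocab) model outputs
    targets: (batch, seq) original token ids
    mask:    (batch, seq) bool, True where a token was masked
    p_mask:  (batch, seq) masking probability applied at each position
    """
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (batch*seq, vocab)
        targets.reshape(-1),                  # (batch*seq,)
        reduction="none",
    ).view_as(p_mask)
    # Importance weighting: tokens masked with low probability get
    # proportionally more weight, correcting the sampling bias.
    weighted = (per_token / p_mask) * mask
    # Normalize by the total token count, matching the formula above.
    return weighted.sum() / targets.numel()
```

Dividing by the total token count rather than the masked count keeps the loss scale comparable across batches with different masking ratios.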

## Performance Tips

### Memory Optimization
- Bidirectional attention uses more memory than causal attention
- Consider reducing `micro_batch_size` if you encounter OOM errors
- Consider using gradient checkpointing and `torch.compile`; see the config sketch after this list
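
As a concrete starting point, the memory tips above could translate into config changes like these (a sketch: `gradient_checkpointing` and `torch_compile` are standard Axolotl options, and the values shown are assumptions to tune for your hardware):

```yaml
gradient_checkpointing: true    # recompute activations to save memory
torch_compile: true             # compile the model with torch.compile
micro_batch_size: 4             # halved from 8 after an OOM...
gradient_accumulation_steps: 8  # ...doubled to keep the effective batch at 32
```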

### Training Stability
- Start with `noise_schedule: linear` for more predictable behavior
- Enable `importance_weighting: true` for better gradient scaling; both options appear in the sketch below
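
For reference, these two options would sit in the same Axolotl YAML as the rest of the training config (a sketch; the flat key placement is an assumption about how the plugin reads its options):

```yaml
noise_schedule: linear      # linear masking-rate schedule
importance_weighting: true  # weight each masked token's loss by 1 / p_mask
```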

### Convergence
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
- Expect different loss curves compared to standard language modeling

## Sample Generation

When `generate_samples: true`, the plugin generates samples during training:

```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
@@ -136,4 +122,4 @@ The plugin adds several metrics to track diffusion training:

## References

- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)