nits
This commit is contained in:
@@ -53,8 +53,9 @@ model_type: llama
|
|||||||
# Standard Axolotl configuration
|
# Standard Axolotl configuration
|
||||||
datasets:
|
datasets:
|
||||||
- path: your_dataset
|
- path: your_dataset
|
||||||
type: completion # or conversation
|
...
|
||||||
|
|
||||||
|
# Other config
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
micro_batch_size: 8
|
micro_batch_size: 8
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -85,31 +86,16 @@ The plugin uses native 4D attention masks to:
|
|||||||
|
|
||||||
Loss is computed only on masked tokens with (optional) importance weighting:
|
Loss is computed only on masked tokens with (optional) importance weighting:
|
||||||
|
|
||||||
```
|
```python
|
||||||
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
|
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
|
||||||
```
|
```
|
||||||
|
|
||||||
## Performance Tips
|
|
||||||
|
|
||||||
### Memory Optimization
|
|
||||||
- Bidirectional attention uses more memory than causal attention
|
|
||||||
- Consider reducing `micro_batch_size` if you encounter OOM errors
|
|
||||||
- Consider using gradient checkpointing, torch.compile,
|
|
||||||
|
|
||||||
### Training Stability
|
|
||||||
- Start with `noise_schedule: linear` for more predictable behavior
|
|
||||||
- Enable `importance_weighting: true` for better gradient scaling
|
|
||||||
|
|
||||||
### Convergence
|
|
||||||
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
|
|
||||||
- Expect different loss curves compared to standard language modeling
|
|
||||||
|
|
||||||
## Sample Generation
|
## Sample Generation
|
||||||
|
|
||||||
When `generate_samples: true`, the plugin generates samples during training:
|
When `generate_samples: true`, the plugin generates samples during training:
|
||||||
|
|
||||||
```
|
```
|
||||||
📝 Sample 1:
|
Sample 1:
|
||||||
Original (45 tokens): The quick brown fox jumps over the lazy dog...
|
Original (45 tokens): The quick brown fox jumps over the lazy dog...
|
||||||
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
|
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
|
||||||
Generated: The quick brown fox jumps over the lazy dog...
|
Generated: The quick brown fox jumps over the lazy dog...
|
||||||
@@ -136,4 +122,4 @@ The plugin adds several metrics to track diffusion training:
|
|||||||
## References
|
## References
|
||||||
|
|
||||||
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
|
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
|
||||||
- [Axolotl Documentation](https://github.com/OpenAccess-AI-Collective/axolotl)
|
- [Axolotl Documentation](https://docs.axolotl.ai/)
|
||||||
|
|||||||
Reference in New Issue
Block a user