diffusion training plugin
This commit is contained in:
60
examples/llama-3/diffusion-3.2-1b-pretrain.yaml
Normal file
60
examples/llama-3/diffusion-3.2-1b-pretrain.yaml
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-1B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# Dataset configuration for pretraining
|
||||||
|
datasets:
|
||||||
|
- path: wikitext
|
||||||
|
name: wikitext-103-raw-v1
|
||||||
|
type: completion
|
||||||
|
field: text
|
||||||
|
val_set_size: 0.001
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- diffusion.DiffusionPlugin
|
||||||
|
noise_schedule: "cosine"
|
||||||
|
min_mask_ratio: 0.15
|
||||||
|
max_mask_ratio: 0.85
|
||||||
|
num_diffusion_steps: 2000
|
||||||
|
eps: 5e-4
|
||||||
|
importance_weighting: true
|
||||||
|
|
||||||
|
output_dir: ./outputs/model-out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 4
|
||||||
|
max_steps: 10000
|
||||||
|
|
||||||
|
optimizer: adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 3e-4
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
sdp_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 500
|
||||||
|
|
||||||
|
save_strategy: steps
|
||||||
|
eval_strategy: steps
|
||||||
|
save_steps: 1000
|
||||||
|
eval_steps: 1000
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
55
examples/llama-3/diffusion-3.2-1b-sft.yaml
Normal file
55
examples/llama-3/diffusion-3.2-1b-sft.yaml
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-1B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
|
type: alpaca
|
||||||
|
val_set_size: 0.05
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- diffusion.DiffusionPlugin
|
||||||
|
noise_schedule: "linear"
|
||||||
|
min_mask_ratio: 0.1
|
||||||
|
max_mask_ratio: 0.9
|
||||||
|
num_diffusion_steps: 1000
|
||||||
|
eps: 1e-3
|
||||||
|
importance_weighting: true
|
||||||
|
|
||||||
|
output_dir: ./outputs/model-out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
optimizer: adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 1e-5
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
sdp_attention: true
|
||||||
|
|
||||||
|
save_strategy: steps
|
||||||
|
eval_strategy: steps
|
||||||
|
save_steps: 500
|
||||||
|
eval_steps: 500
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -385,10 +385,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
)
|
)
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
|
|
||||||
|
# Check if trainer class inherits from transformers.Trainer
|
||||||
|
# If so, we should pass the tokenizer/processing_class even if not in direct signature
|
||||||
|
from transformers import Trainer as HFTrainer
|
||||||
|
|
||||||
if "processing_class" in sig.parameters:
|
if "processing_class" in sig.parameters:
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
elif "tokenizer" in sig.parameters:
|
elif "tokenizer" in sig.parameters:
|
||||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
elif issubclass(trainer_cls, HFTrainer):
|
||||||
|
# For subclasses of transformers.Trainer, try processing_class first (newer HF versions)
|
||||||
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
if (
|
if (
|
||||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||||
and self.cfg.datasets is not None
|
and self.cfg.datasets is not None
|
||||||
|
|||||||
117
src/axolotl/integrations/diffusion/README.md
Normal file
117
src/axolotl/integrations/diffusion/README.md
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# Diffusion LM Training Plugin for Axolotl
|
||||||
|
|
||||||
|
This plugin enables diffusion language model training using the LLaDA (Large Language
|
||||||
|
Diffusion with mAsking) approach within the Axolotl framework.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LLaDA is a diffusion-based approach to language model training that uses:
|
||||||
|
- **Random token masking** during training instead of next-token prediction
|
||||||
|
- **Bidirectional attention** to allow the model to see the full context
|
||||||
|
- **Importance weighting** based on masking probabilities for stable training
|
||||||
|
|
||||||
|
This approach can lead to more robust language models with better understanding of
|
||||||
|
bidirectional context.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The plugin is included with Axolotl. To use it, simply add the plugin configuration to
|
||||||
|
your training config.
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
### Basic Configuration
|
||||||
|
|
||||||
|
Add the following to your Axolotl configuration YAML:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Enable diffusion LM training plugin
|
||||||
|
plugins:
|
||||||
|
- diffusion.DiffusionPlugin
|
||||||
|
|
||||||
|
# Diffusion-specific configuration
|
||||||
|
noise_schedule: "linear" # or "cosine"
|
||||||
|
min_mask_ratio: 0.1
|
||||||
|
max_mask_ratio: 0.9
|
||||||
|
num_diffusion_steps: 1000
|
||||||
|
eps: 1e-3
|
||||||
|
importance_weighting: true
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
base_model: meta-llama/Llama-3.2-1B
|
||||||
|
model_type: llama
|
||||||
|
|
||||||
|
# Standard Axolotl configuration
|
||||||
|
datasets:
|
||||||
|
- path: your_dataset
|
||||||
|
type: completion # or conversation
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
micro_batch_size: 8
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
learning_rate: 3e-4
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
Any models that support 4D attention masks should work out of the box. If not, please
|
||||||
|
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
### Random Masking
|
||||||
|
During training, tokens are randomly masked based on a sampled timestep:
|
||||||
|
- Sample timestep `t` uniformly from [0, 1]
|
||||||
|
- Calculate masking probability: `p = (1 - eps) * t + eps`
|
||||||
|
- Randomly mask tokens with probability `p`
|
||||||
|
|
||||||
|
### Bidirectional Attention
|
||||||
|
The plugin uses native 4D attention masks to:
|
||||||
|
- Enable bidirectional attention without patches
|
||||||
|
- Allow all tokens to attend to all other tokens
|
||||||
|
- Maintain proper padding masks
|
||||||
|
- Work with modern `transformers` models out of the box
|
||||||
|
|
||||||
|
### Diffusion Loss
|
||||||
|
|
||||||
|
Loss is computed only on masked tokens with (optional) importance weighting:
|
||||||
|
|
||||||
|
```
|
||||||
|
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Tips
|
||||||
|
|
||||||
|
### Memory Optimization
|
||||||
|
- Bidirectional attention uses more memory than causal attention
|
||||||
|
- Consider reducing `micro_batch_size` if you encounter OOM errors
|
||||||
|
- Consider using gradient checkpointing and `torch.compile`
|
||||||
|
|
||||||
|
### Training Stability
|
||||||
|
- Start with `noise_schedule: "linear"` for more predictable behavior
|
||||||
|
- Enable `importance_weighting` for better gradient scaling
|
||||||
|
|
||||||
|
### Convergence
|
||||||
|
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
|
||||||
|
- Expect different loss curves compared to standard language modeling
|
||||||
|
|
||||||
|
## Metrics and Monitoring
|
||||||
|
|
||||||
|
The plugin adds several metrics to track diffusion training:
|
||||||
|
|
||||||
|
- `train/diffusion_loss`: Weighted diffusion loss
|
||||||
|
- `train/diffusion_accuracy`: Accuracy on masked tokens
|
||||||
|
- `train/diffusion_mask_ratio`: Average fraction of tokens masked
|
||||||
|
- `train/diffusion_num_masked_tokens`: Number of tokens masked
|
||||||
|
- `train/diffusion_avg_p_mask`: Average masking probability
|
||||||
|
- `train/diffusion_ce_loss`: Unweighted cross-entropy loss
|
||||||
|
- `train/diffusion_importance_weight_avg`: Average importance weight
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- No flash attention support
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [LLaDA Paper](https://arxiv.org/abs/2502.09992)
|
||||||
|
- [Axolotl Documentation](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
10
src/axolotl/integrations/diffusion/__init__.py
Normal file
10
src/axolotl/integrations/diffusion/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Diffusion LM training plugin for Axolotl.
|
||||||
|
|
||||||
|
This plugin enables diffusion language model training using the LLaDA approach.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .args import DiffusionArgs
|
||||||
|
from .plugin import DiffusionPlugin
|
||||||
|
|
||||||
|
__all__ = ["DiffusionArgs", "DiffusionPlugin"]
|
||||||
43
src/axolotl/integrations/diffusion/args.py
Normal file
43
src/axolotl/integrations/diffusion/args.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Configuration arguments for diffusion LM training."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionArgs(BaseModel):
    """Pydantic schema describing the options accepted by the diffusion LM plugin.

    Every field has a default, so a config only needs to list the values it
    wants to override.
    """

    # --- Noise schedule -------------------------------------------------
    # Restricted to the two schedule names spelled out in the Literal.
    noise_schedule: Literal["linear", "cosine"] = Field(
        "linear",
        description="Type of noise schedule for diffusion training",
    )
    # Both mask ratios are validated to lie in the closed interval [0, 1].
    min_mask_ratio: float = Field(
        0.1,
        description="Minimum masking ratio for diffusion noise schedule",
        ge=0.0,
        le=1.0,
    )
    max_mask_ratio: float = Field(
        0.9,
        description="Maximum masking ratio for diffusion noise schedule",
        ge=0.0,
        le=1.0,
    )
    # Must be a positive integer (>= 1).
    num_diffusion_steps: int = Field(
        1000,
        description="Number of diffusion timesteps",
        ge=1,
    )

    # --- Forward process ------------------------------------------------
    # Validated to lie in the closed interval [0, 1].
    eps: float = Field(
        1e-3,
        description="Epsilon value for minimum masking probability in forward process",
        ge=0.0,
        le=1.0,
    )

    # --- Training -------------------------------------------------------
    importance_weighting: bool = Field(
        True,
        description="Apply importance weighting to loss based on masking probability",
    )
|
||||||
40
src/axolotl/integrations/diffusion/plugin.py
Normal file
40
src/axolotl/integrations/diffusion/plugin.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Diffusion LM training plugin for Axolotl."""
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel, Trainer
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .trainer import DiffusionTrainer
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionPlugin(BasePlugin):
    """
    Plugin for diffusion language model training.

    This plugin enables diffusion-based training using the LLaDA approach, which uses
    random masking and bidirectional attention to train language models.
    """

    def __init__(self):
        super().__init__()
        # Filled in by `post_model_load`; stays None until that hook runs.
        self.cfg: DictDefault | None = None

    def get_input_args(self) -> str:
        """Return the dotted import path of the pydantic model with the plugin args."""
        return "axolotl.integrations.diffusion.DiffusionArgs"

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel):
        """Remember the config once the model is loaded; `model` is not used."""
        self.cfg = cfg

    def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
        """Return the trainer *class* (not an instance) used for diffusion training."""
        return DiffusionTrainer

    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
        """Hand the config to the freshly created trainer via its `set_config`."""
        trainer.set_config(cfg)
|
||||||
245
src/axolotl/integrations/diffusion/trainer.py
Normal file
245
src/axolotl/integrations/diffusion/trainer.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Custom trainer for diffusion LM training."""
|
||||||
|
|
||||||
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
    """Custom trainer for diffusion LM training that overrides loss computation.

    The loss is computed by randomly masking tokens, running the model with a
    bidirectional 4D attention mask and scoring only the masked positions.
    `set_config` must be called before the first loss computation because the
    loss reads `eps` and `importance_weighting` from the config.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populated later through `set_config`.
        self.config = None

    def set_config(self, config: DictDefault):
        """Set config for diffusion training."""
        self.config = config

    def _resolve_mask_token_id(self) -> int:
        """Return the token id used to replace masked tokens.

        Prefers the tokenizer's `mask_token_id` and falls back to its
        `unk_token_id`.

        Raises:
            ValueError: If no tokenizer is attached to the trainer, or the
                tokenizer defines neither a mask token nor an unk token.
        """
        tokenizer = self.processing_class
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if tokenizer is None:
            raise ValueError("Tokenizer not available on Trainer object.")

        mask_token_id = getattr(tokenizer, "mask_token_id", None)
        if mask_token_id is None:
            mask_token_id = getattr(tokenizer, "unk_token_id", None)
        if mask_token_id is None:
            # Without this check `torch.where` would receive `None` and fail
            # with an error that does not point at the tokenizer.
            raise ValueError(
                "Diffusion training needs a mask token, but the tokenizer defines "
                "neither `mask_token_id` nor `unk_token_id`. Add a mask token to "
                "the tokenizer before training."
            )
        return mask_token_id

    def forward_process(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        eps: float = 1e-3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward noising process. One timestep `t` is sampled uniformly per sequence
        and every token of that sequence is masked with probability
        `(1 - eps) * t + eps`.

        Only `eps` influences the masking probability here; the
        `noise_schedule`, `min_mask_ratio`, `max_mask_ratio` and
        `num_diffusion_steps` options are not read by this method.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len]. Positions where
                it is zero are never masked and get a masking probability of 0.
            eps: Small epsilon value for minimum masking probability.

        Returns:
            noisy_batch: Input with some tokens masked.
            masked_indices: Boolean mask indicating which tokens were masked.
            p_mask: Masking probabilities for each token [batch_size, seq_len].

        Raises:
            ValueError: If no usable mask token id can be resolved.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Fail early, before any work is done, if there is no mask token.
        mask_token_id = self._resolve_mask_token_id()

        # Sample random timesteps for each sample in batch
        t = torch.rand(batch_size, device=device)

        # Calculate masking probability with epsilon
        p_mask = (1 - eps) * t + eps  # [batch_size]
        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]

        # Don't mask padding tokens if attention_mask is provided
        valid_mask = None
        if attention_mask is not None:
            valid_mask = attention_mask.bool()
            p_mask = p_mask * valid_mask.float()

        # Create random mask based on probability
        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
        if valid_mask is not None:
            masked_indices = masked_indices & valid_mask

        # Create masked input using configured mask token
        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)

        return noisy_batch, masked_indices, p_mask

    def create_bidirectional_attention_mask(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Create bidirectional attention mask to override default causal masking.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            bidirectional_mask: Boolean 4D attention mask
                [batch_size, 1, seq_len, seq_len]. An entry is True only when
                both positions of the pair are non-zero in `attention_mask`
                (or everywhere when no `attention_mask` is given).
        """
        batch_size, seq_len = input_ids.shape

        # Start from "every position may attend to every position".
        bidirectional_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
        )
        bidirectional_mask = (
            bidirectional_mask.unsqueeze(0)
            .unsqueeze(0)
            .expand(batch_size, 1, seq_len, seq_len)
        )

        # Apply padding mask if provided
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len] -> 4D
            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)

            # AND with the transpose so padding is excluded along both of the
            # last two dimensions.
            bidirectional_mask = (
                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
            )

        return bidirectional_mask

    def compute_diffusion_loss(
        self,
        model: PreTrainedModel,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute diffusion loss.

        Args:
            model: The model to compute loss for.
            input_ids: Ground truth token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            loss: Sum of the (optionally importance-weighted) cross-entropy over
                masked tokens, divided by `batch_size * seq_len`.
            metrics: Dictionary of Python scalars describing this batch.

        Raises:
            RuntimeError: If `set_config` has not been called yet.
        """
        if self.config is None:
            raise RuntimeError(
                "DiffusionTrainer.set_config() must be called before computing "
                "the diffusion loss."
            )
        use_weighting = bool(self.config.importance_weighting)

        # Apply forward process
        noisy_batch, masked_indices, p_mask = self.forward_process(
            input_ids, attention_mask, self.config.eps
        )

        # Create bidirectional attention mask (always required for diffusion training)
        bidirectional_mask = self.create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        # Forward pass
        outputs = model(
            input_ids=noisy_batch,
            attention_mask=bidirectional_mask,
        )
        logits = outputs.logits

        # Apply attention mask to masked_indices if provided
        if attention_mask is not None:
            loss_mask = masked_indices & attention_mask.bool()
        else:
            loss_mask = masked_indices

        # Count once; every later decision reuses this Python int instead of
        # re-reducing the mask.
        num_masked = int(loss_mask.sum().item())

        # Boolean indexing keeps row-major order, so the three selections
        # below stay aligned with one another.
        masked_logits = logits[loss_mask]  # [num_masked_tokens, vocab_size]

        if num_masked > 0:
            masked_targets = input_ids[loss_mask]  # [num_masked_tokens]
            masked_p_mask = p_mask[loss_mask].float()  # [num_masked_tokens]

            # Compute cross-entropy loss without reduction (cast to fp32 for stability)
            token_loss = F.cross_entropy(
                masked_logits.float(), masked_targets, reduction="none"
            )

            # Apply importance weighting if enabled
            weighted_loss = token_loss / masked_p_mask if use_weighting else token_loss

            # Final loss: sum weighted losses, normalize by total tokens
            loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])

            with torch.no_grad():
                ce_loss_value = token_loss.mean().item()
                pred_tokens = masked_logits.argmax(dim=-1)
                accuracy_value = (pred_tokens == masked_targets).float().mean().item()
                importance_weight_avg = (1.0 / masked_p_mask).mean().item()
        else:
            # Nothing was masked. Sum over the empty selection so the result is
            # 0.0 yet still attached to the model's graph; a detached leaf
            # tensor would skip backward through the model entirely.
            loss = masked_logits.float().sum()
            ce_loss_value = 0.0
            accuracy_value = 0.0
            importance_weight_avg = 0.0

        metrics = {
            "loss": loss.item(),
            "accuracy": accuracy_value,
            "mask_ratio": masked_indices.float().mean().item(),
            "num_masked_tokens": num_masked,
            "avg_p_mask": (
                p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
            ),
            "ce_loss": ce_loss_value,
        }

        if use_weighting:
            metrics["importance_weight_avg"] = importance_weight_avg

        return loss, metrics

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Override compute_loss to use diffusion loss.

        Args:
            model: The model to compute loss for.
            inputs: Batch dictionary; `input_ids` is required and
                `attention_mask` is optional.
            return_outputs: When True, return `(loss, outputs)` where `outputs`
                is currently just `[loss]`.
            num_items_in_batch: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If `inputs` has no `input_ids`.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        if input_ids is None:
            raise ValueError("input_ids is required for diffusion training")

        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)

        # Log metrics: one `log` call per batch, labelled by the model's mode so
        # evaluation batches are not reported under the `train/` prefix.
        if self.state.is_local_process_zero:
            prefix = "train" if model.training else "eval"
            self.log(
                {f"{prefix}/diffusion_{key}": value for key, value in metrics.items()}
            )

        if return_outputs:
            # TODO: compute outputs (?)
            outputs = [loss]
            return (loss, outputs)

        return loss
|
||||||
Reference in New Issue
Block a user