diffusion training plugin
60
examples/llama-3/diffusion-3.2-1b-pretrain.yaml
Normal file
@@ -0,0 +1,60 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

# Dataset configuration for pretraining
datasets:
  - path: wikitext
    name: wikitext-103-raw-v1
    type: completion
    field: text
val_set_size: 0.001

plugins:
  - diffusion.DiffusionPlugin
noise_schedule: "cosine"
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 2000
eps: 5e-4
importance_weighting: true

output_dir: ./outputs/model-out

sequence_len: 512
sample_packing: true
eval_sample_packing: true

gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true

warmup_steps: 500

save_strategy: steps
eval_strategy: steps
save_steps: 1000
eval_steps: 1000

special_tokens:
  pad_token: "<|end_of_text|>"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
55
examples/llama-3/diffusion-3.2-1b-sft.yaml
Normal file
@@ -0,0 +1,55 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
val_set_size: 0.05

plugins:
  - diffusion.DiffusionPlugin
noise_schedule: "linear"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 1000
eps: 1e-3
importance_weighting: true

output_dir: ./outputs/model-out

sequence_len: 512
sample_packing: true
eval_sample_packing: true

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5

bf16: auto
tf32: true

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true

save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500

special_tokens:
  pad_token: "<|end_of_text|>"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -385,10 +385,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            **data_collator_kwargs,
        )
        sig = inspect.signature(trainer_cls)

        # Check if trainer class inherits from transformers.Trainer
        # If so, we should pass the tokenizer/processing_class even if not in direct signature
        from transformers import Trainer as HFTrainer

        if "processing_class" in sig.parameters:
            trainer_kwargs["processing_class"] = self.tokenizer
        elif "tokenizer" in sig.parameters:
            trainer_kwargs["tokenizer"] = self.tokenizer
        elif issubclass(trainer_cls, HFTrainer):
            # For subclasses of transformers.Trainer, try processing_class first (newer HF versions)
            trainer_kwargs["processing_class"] = self.tokenizer
        if (
            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
            and self.cfg.datasets is not None
117
src/axolotl/integrations/diffusion/README.md
Normal file
@@ -0,0 +1,117 @@
# Diffusion LM Training Plugin for Axolotl

This plugin enables diffusion language model training using the LLaDA (Large Language
And Diffusion Assistant) approach within the Axolotl framework.

## Overview

LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to see the full context
- **Importance weighting** based on masking probabilities for stable training

This approach can lead to more robust language models with a better understanding of
bidirectional context.

## Installation

The plugin is included with Axolotl. To use it, simply add the plugin configuration to
your training config.

## Quickstart

### Basic Configuration

Add the following to your Axolotl configuration YAML:

```yaml
# Enable diffusion LM training plugin
plugins:
  - diffusion.DiffusionPlugin

# Diffusion-specific configuration
noise_schedule: "linear" # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 1000
eps: 1e-3
importance_weighting: true

# Model configuration
base_model: meta-llama/Llama-3.2-1B
model_type: llama

# Standard Axolotl configuration
datasets:
  - path: your_dataset
    type: completion # or conversation

sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 3e-4
```

## Supported Models

Any models that support 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!

## How It Works

### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
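
As a rough numeric illustration of that schedule (a standalone sketch, not the plugin's implementation; the toy batch shape and the `mask_token_id` value below are placeholders):

```python
import torch

eps = 1e-3
input_ids = torch.randint(0, 32000, (2, 8))  # toy batch: [batch_size=2, seq_len=8]
mask_token_id = 128002                       # placeholder; the plugin reads this from the tokenizer

t = torch.rand(input_ids.shape[0])           # one timestep per sample, uniform in [0, 1]
p_mask = ((1 - eps) * t + eps)[:, None].expand_as(input_ids)  # per-token masking probability

masked = torch.rand(input_ids.shape) < p_mask                 # Bernoulli draw per token
noisy_ids = torch.where(masked, torch.full_like(input_ids, mask_token_id), input_ids)
```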

### Bidirectional Attention
The plugin uses native 4D attention masks to:
- Enable bidirectional attention without patches
- Allow all tokens to attend to all other tokens
- Maintain proper padding masks
- Work with modern `transformers` models out of the box
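
For illustration, a minimal sketch of such a 4D mask (shapes follow the `[batch, 1, seq_len, seq_len]` convention that `transformers` accepts for pre-computed masks; the toy padding mask is made up):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # [batch=1, seq_len=4]; last position is padding
batch, seq_len = attention_mask.shape

full = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.bool)  # fully bidirectional
pad = attention_mask.bool()[:, None, None, :]                    # [batch, 1, 1, seq_len]

# Block attention to and from padding positions, keep everything else bidirectional.
bidirectional_mask = full & pad & pad.transpose(-1, -2)          # [batch, 1, seq_len, seq_len]
```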

### Diffusion Loss

Loss is computed only on masked tokens with (optional) importance weighting:

```
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
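
In code form, the idea is roughly the following (a sketch of the formula above rather than the trainer implementation; `logits`, `input_ids`, `masked`, and `p_mask` are assumed to come from the forward pass and the masking step):

```python
import torch.nn.functional as F

# logits: [batch, seq_len, vocab]; input_ids, masked, p_mask: [batch, seq_len]
token_loss = F.cross_entropy(logits[masked], input_ids[masked], reduction="none")
weighted = token_loss / p_mask[masked]     # importance weighting by masking probability
loss = weighted.sum() / input_ids.numel()  # normalize by the total number of tokens
```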

## Performance Tips

### Memory Optimization
- Bidirectional attention uses more memory than causal attention
- Consider reducing `micro_batch_size` if you encounter OOM errors
- Consider using gradient checkpointing and `torch.compile`

### Training Stability
- Start with `noise_schedule: "linear"` for more predictable behavior
- Enable `importance_weighting` for better gradient scaling

### Convergence
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
- Expect different loss curves compared to standard language modeling

## Metrics and Monitoring

The plugin adds several metrics to track diffusion training:

- `train/diffusion_loss`: Weighted diffusion loss
- `train/diffusion_accuracy`: Accuracy on masked tokens
- `train/diffusion_mask_ratio`: Average fraction of tokens masked
- `train/diffusion_num_masked_tokens`: Number of tokens masked
- `train/diffusion_avg_p_mask`: Average masking probability
- `train/diffusion_ce_loss`: Unweighted cross-entropy loss
- `train/diffusion_importance_weight_avg`: Average importance weight

## Limitations

- No flash attention support

## References

- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://github.com/OpenAccess-AI-Collective/axolotl)
10
src/axolotl/integrations/diffusion/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""
Diffusion LM training plugin for Axolotl.

This plugin enables diffusion language model training using the LLaDA approach.
"""

from .args import DiffusionArgs
from .plugin import DiffusionPlugin

__all__ = ["DiffusionArgs", "DiffusionPlugin"]
43
src/axolotl/integrations/diffusion/args.py
Normal file
@@ -0,0 +1,43 @@
"""Configuration arguments for diffusion LM training."""

from typing import Literal

from pydantic import BaseModel, Field


class DiffusionArgs(BaseModel):
    """Arguments for diffusion LM training plugin."""

    # Noise schedule configuration
    noise_schedule: Literal["linear", "cosine"] = Field(
        default="linear", description="Type of noise schedule for diffusion training"
    )
    min_mask_ratio: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Minimum masking ratio for diffusion noise schedule",
    )
    max_mask_ratio: float = Field(
        default=0.9,
        ge=0.0,
        le=1.0,
        description="Maximum masking ratio for diffusion noise schedule",
    )
    num_diffusion_steps: int = Field(
        default=1000, ge=1, description="Number of diffusion timesteps"
    )

    # Forward process parameters
    eps: float = Field(
        default=1e-3,
        ge=0.0,
        le=1.0,
        description="Epsilon value for minimum masking probability in forward process",
    )

    # Training configuration
    importance_weighting: bool = Field(
        default=True,
        description="Apply importance weighting to loss based on masking probability",
    )
40
src/axolotl/integrations/diffusion/plugin.py
Normal file
@@ -0,0 +1,40 @@
"""Diffusion LM training plugin for Axolotl."""

from transformers import PreTrainedModel, Trainer

from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from .trainer import DiffusionTrainer

LOG = get_logger(__name__)


class DiffusionPlugin(BasePlugin):
    """
    Plugin for diffusion language model training.

    This plugin enables diffusion-based training using the LLaDA approach, which uses
    random masking and bidirectional attention to train language models.
    """

    def __init__(self):
        super().__init__()
        self.cfg = None

    def get_input_args(self) -> str:
        """Returns the pydantic model for LLaDA plugin arguments."""
        return "axolotl.integrations.diffusion.DiffusionArgs"

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel):
        """Perform actions after model is loaded."""
        self.cfg = cfg

    def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
        """Return custom trainer class for diffusion training."""
        return DiffusionTrainer

    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
        """Configure trainer after creation."""
        trainer.set_config(cfg)
245
src/axolotl/integrations/diffusion/trainer.py
Normal file
@@ -0,0 +1,245 @@
"""Custom trainer for diffusion LM training."""

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
    """Custom trainer for diffusion LM training that overrides loss computation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = None

    def set_config(self, config: DictDefault):
        """Set config for diffusion training."""
        self.config = config

    def forward_process(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        eps: float = 1e-3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward noising process. A timestep is sampled along the process, and tokens are
        masked with probability determined by the configured noise schedule.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            eps: Small epsilon value for minimum masking probability.

        Returns:
            noisy_batch: Input with some tokens masked.
            masked_indices: Boolean mask indicating which tokens were masked.
            p_mask: Masking probabilities for each token [batch_size, seq_len].
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Sample random timesteps for each sample in batch
        t = torch.rand(batch_size, device=device)

        # Calculate masking probability with epsilon
        p_mask = (1 - eps) * t + eps  # [batch_size]
        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]

        # Don't mask padding tokens if attention_mask is provided
        if attention_mask is not None:
            valid_mask = attention_mask.bool()
            p_mask = p_mask * valid_mask.float()

        # Create random mask based on probability
        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
        if attention_mask is not None:
            masked_indices = masked_indices & attention_mask.bool()

        # Get tokenizer
        tokenizer = self.processing_class
        assert tokenizer is not None, "Tokenizer not available on Trainer object."

        # Get mask token ID
        mask_token_id = getattr(tokenizer, "mask_token_id", None)
        if mask_token_id is None:
            mask_token_id = getattr(tokenizer, "unk_token_id", None)

        # Create masked input using configured mask token
        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)

        return noisy_batch, masked_indices, p_mask

    def create_bidirectional_attention_mask(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Create bidirectional attention mask to override default causal masking.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
        """
        batch_size, seq_len = input_ids.shape

        # Create bidirectional attention mask to override default causal masking
        # Shape: [batch_size, 1, seq_len, seq_len]
        bidirectional_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
        )
        bidirectional_mask = (
            bidirectional_mask.unsqueeze(0)
            .unsqueeze(0)
            .expand(batch_size, 1, seq_len, seq_len)
        )

        # Apply padding mask if provided
        if attention_mask is not None:
            # Convert attention_mask to 4D and apply
            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)

            bidirectional_mask = (
                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
            )

        return bidirectional_mask

    def compute_diffusion_loss(
        self,
        model: PreTrainedModel,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute diffusion loss.

        Args:
            model: The model to compute loss for.
            input_ids: Ground truth token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            loss: Cross-entropy loss.
            metrics: Dictionary of metrics.
        """
        # Apply forward process
        noisy_batch, masked_indices, p_mask = self.forward_process(
            input_ids, attention_mask, self.config.eps
        )

        # Create bidirectional attention mask (always required for diffusion training)
        bidirectional_mask = self.create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        # Forward pass
        outputs = model(
            input_ids=noisy_batch,
            attention_mask=bidirectional_mask,
        )
        logits = outputs.logits

        # Apply attention mask to masked_indices if provided
        if attention_mask is not None:
            loss_mask = masked_indices & attention_mask.bool()
        else:
            loss_mask = masked_indices

        if loss_mask.sum() > 0:
            valid_indices = torch.where(loss_mask)
            batch_indices, seq_indices = valid_indices

            # Extract the relevant data
            masked_logits = logits[
                batch_indices, seq_indices
            ]  # [num_masked_tokens, vocab_size]
            masked_targets = input_ids[
                batch_indices, seq_indices
            ]  # [num_masked_tokens]
            masked_p_mask = p_mask[batch_indices, seq_indices]  # [num_masked_tokens]

            # Compute cross-entropy loss without reduction (cast to fp32 for stability)
            token_loss = F.cross_entropy(
                masked_logits.float(), masked_targets, reduction="none"
            )

            # Apply importance weighting if enabled
            if self.config.importance_weighting:
                masked_p_mask = masked_p_mask.float()
                weighted_loss = token_loss / masked_p_mask
            else:
                weighted_loss = token_loss

            # Final loss: sum weighted losses, normalize by total tokens
            loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
            ce_loss = token_loss.mean()

            # Compute accuracy on masked tokens
            with torch.no_grad():
                pred_tokens = masked_logits.argmax(dim=-1)
                accuracy = (pred_tokens == masked_targets).float().mean()
        else:
            loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
            accuracy = torch.tensor(0.0, device=input_ids.device)
            ce_loss = torch.tensor(0.0, device=input_ids.device)
            masked_p_mask = torch.tensor(1.0, device=input_ids.device)

        metrics = {
            "loss": loss.item(),
            "accuracy": accuracy.item(),
            "mask_ratio": masked_indices.float().mean().item(),
            "num_masked_tokens": loss_mask.sum().item(),
            "avg_p_mask": (
                p_mask[masked_indices].mean().item()
                if masked_indices.sum() > 0
                else 0.0
            ),
            "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
        }

        if self.config.importance_weighting:
            metrics["importance_weight_avg"] = (
                (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
            )

        return loss, metrics

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Override compute_loss to use diffusion loss."""
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        if input_ids is None:
            raise ValueError("input_ids is required for diffusion training")

        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)

        # Log metrics
        if self.state.is_local_process_zero:
            for key, value in metrics.items():
                self.log({f"train/diffusion_{key}": value})

        if return_outputs:
            # TODO: compute outputs (?)
            outputs = [loss]
            return (loss, outputs)

        return loss