fixes + improvements

Dan Saunders
2025-08-14 16:11:37 -04:00
parent 0a9341acde
commit 479a454ae3
3 changed files with 78 additions and 69 deletions

View File

@@ -22,8 +22,8 @@ importance_weighting: true
 output_dir: ./outputs/model-out
 sequence_len: 512
-sample_packing: true
-eval_sample_packing: true
+sample_packing: false
+eval_sample_packing: false
 gradient_accumulation_steps: 8
 micro_batch_size: 4
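With packing now off, each sequence occupies its own row, so the effective per-device batch size comes straight from the two batching keys above. A quick check of the arithmetic (illustrative only):

micro_batch_size = 4
gradient_accumulation_steps = 8
sequence_len = 512

effective_batch = micro_batch_size * gradient_accumulation_steps
print(effective_batch)                 # 32 sequences per optimizer step
print(effective_batch * sequence_len)  # up to 16,384 tokens per step, minus padding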
@@ -51,8 +51,8 @@ eval_steps: 1000
 special_tokens:
   pad_token: "<|end_of_text|>"
-wandb_project:
-wandb_entity:
+wandb_project: diffusion-plugin
+wandb_entity: axolotl-ai
 wandb_watch:
 wandb_name:
 wandb_log_model:
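The two populated keys map onto the usual wandb.init arguments; axolotl wires this up internally, so the following is only a sketch of where the run will land:

import wandb

# Equivalent of the wandb_project / wandb_entity keys above (illustrative).
run = wandb.init(project="diffusion-plugin", entity="axolotl-ai")
run.log({"train/diffusion_loss": 0.0})  # metrics land under this project/entity
run.finish()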

View File

@@ -41,3 +41,7 @@ class DiffusionArgs(BaseModel):
         default=True,
         description="Apply importance weighting to loss based on masking probability",
     )
+    mask_token_id: int = Field(
+        default=128002,
+        description="Token ID to use for masking. Default is 128002 (<|reserved_special_token_0|> for Llama 3.2)",
+    )
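Token 128002 is one of the reserved special tokens in the Llama 3 family's vocabulary, so it never appears in natural text. To confirm the mapping for a given checkpoint (a sketch; the model name is an assumption):

from transformers import AutoTokenizer

# Assumed checkpoint; substitute whichever Llama 3.x model the config targets.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
print(tok.convert_ids_to_tokens(128002))  # expected: <|reserved_special_token_0|>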

View File

@@ -1,10 +1,8 @@
 """Custom trainer for diffusion LM training."""

-from typing import Dict, Optional, Tuple, Union
-
 import torch
 import torch.nn.functional as F
-from transformers import PreTrainedModel
+from torch import nn

 from axolotl.core.trainers.base import AxolotlTrainer
 from axolotl.utils.dict import DictDefault
@@ -19,17 +17,37 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.config = None
+        self._special_token_ids = None

     def set_config(self, config: DictDefault):
         """Set config for diffusion training."""
         self.config = config
+        self._cache_special_token_ids()
+
+    def _cache_special_token_ids(self):
+        """Cache special token IDs to avoid repeated tokenizer access."""
+        if self.processing_class is None:
+            self._special_token_ids = set()
+            return
+
+        tokenizer = self.processing_class
+        special_tokens = set()
+        if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
+            special_tokens.add(tokenizer.bos_token_id)
+        if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
+            special_tokens.add(tokenizer.eos_token_id)
+        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
+            special_tokens.add(tokenizer.pad_token_id)
+        self._special_token_ids = special_tokens

     def forward_process(
         self,
         input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor | None = None,
         eps: float = 1e-3,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Forward noising process. A timestep is sampled along the process, and tokens are
         masked with probability determined by the configured noise schedule.
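The cache above covers BOS/EOS/PAD explicitly. For reference, Hugging Face tokenizers also expose all_special_ids, which typically catches added and reserved special tokens as well; a sketch of that broader variant (checkpoint name assumed):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint

# all_special_ids covers every registered special token in one call,
# rather than checking bos/eos/pad attributes individually.
special_ids = set(tok.all_special_ids)

Whether the narrower or broader set is the right choice depends on whether reserved tokens should remain maskable during training.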
@@ -59,19 +77,20 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             valid_mask = attention_mask.bool()
             p_mask = p_mask * valid_mask.float()

-        # Create random mask based on probability
+        # Create mask to exclude special tokens (BOS, EOS, PAD) using cached IDs
+        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+        if self._special_token_ids:
+            for token_id in self._special_token_ids:
+                special_token_mask |= input_ids == token_id
+
+        # Create random mask based on probability, excluding special tokens
         masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
+        masked_indices = masked_indices & ~special_token_mask
         if attention_mask is not None:
             masked_indices = masked_indices & attention_mask.bool()

-        # Get tokenizer
-        tokenizer = self.processing_class
-        assert tokenizer is not None, "Tokenizer not available on Trainer object."
-
-        # Get mask token ID
-        mask_token_id = getattr(tokenizer, "mask_token_id", None)
-        if mask_token_id is None:
-            mask_token_id = getattr(tokenizer, "unk_token_id", None)
+        # Get mask token ID from config
+        mask_token_id = self.config.mask_token_id

         # Create masked input using configured mask token
         noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
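End to end, the forward process now reduces to: sample a timestep per sequence, convert it into a per-token masking probability, and overwrite the selected (non-special) tokens with the configured mask id. A self-contained sketch under a linear schedule (names are illustrative, not the plugin's API):

import torch

def forward_noise(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    bsz, seq_len = input_ids.shape
    t = torch.rand(bsz, device=input_ids.device)      # one timestep per sample
    p_mask = eps + (1 - eps) * t                      # linear noise schedule
    p_mask = p_mask[:, None].expand(bsz, seq_len)     # broadcast over positions
    masked = torch.rand(bsz, seq_len, device=input_ids.device) < p_mask
    noisy = torch.where(masked, mask_token_id, input_ids)
    return noisy, masked, p_mask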
@@ -79,49 +98,47 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
         return noisy_batch, masked_indices, p_mask

     def create_bidirectional_attention_mask(
-        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
         """
         Create bidirectional attention mask to override default causal masking.

+        Handles sample-packed sequences where different samples are identified
+        by different attention mask values.
+
         Args:
             input_ids: Input token ids [batch_size, seq_len].
-            attention_mask: Attention mask [batch_size, seq_len].
+            attention_mask: Attention mask [batch_size, seq_len]

         Returns:
             bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
         """
         batch_size, seq_len = input_ids.shape
+        device = input_ids.device

-        # Create bidirectional attention mask to override default causal masking
-        # Shape: [batch_size, 1, seq_len, seq_len]
-        bidirectional_mask = torch.ones(
-            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
-        )
-        bidirectional_mask = (
-            bidirectional_mask.unsqueeze(0)
-            .unsqueeze(0)
-            .expand(batch_size, 1, seq_len, seq_len)
-        )
-
-        # Apply padding mask if provided
-        if attention_mask is not None:
-            # Convert attention_mask to 4D and apply
-            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
-            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)
-            bidirectional_mask = (
-                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
-            )
+        if attention_mask is None or not self.config.sample_packing:
+            # Simple case: no attention mask, allow all-to-all attention
+            return torch.ones(
+                batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
+            )
+
+        # Create attention mask by comparing sample IDs element-wise
+        mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
+        mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]
+
+        # Tokens can attend to each other if they have the same non-zero sample ID
+        bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
+
+        # Add head dimension: [batch_size, 1, seq_len, seq_len]
+        bidirectional_mask = bidirectional_mask.unsqueeze(1)

         return bidirectional_mask

     def compute_diffusion_loss(
         self,
-        model: PreTrainedModel,
+        model: nn.Module,
         input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Dict[str, float]]:
+        attention_mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, dict[str, float]]:
         """
         Compute diffusion loss.
@@ -139,7 +156,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             input_ids, attention_mask, self.config.eps
         )

-        # Create bidirectional attention mask (always required for diffusion training)
+        # Create bidirectional attention mask
         bidirectional_mask = self.create_bidirectional_attention_mask(
             input_ids, attention_mask
         )
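The packing-aware branch of create_bidirectional_attention_mask is easiest to see on a toy batch: positions may attend to each other exactly when their attention-mask values match and are non-zero. For example:

import torch

attention_mask = torch.tensor([[1, 1, 2, 2, 0]])  # two packed samples plus one pad
mask_i = attention_mask.unsqueeze(2)              # [batch, seq, 1]
mask_j = attention_mask.unsqueeze(1)              # [batch, 1, seq]
bidirectional = (mask_i == mask_j) & (mask_i > 0)
print(bidirectional.int()[0])
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]])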
@@ -151,14 +168,8 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
         )
         logits = outputs.logits

-        # Apply attention mask to masked_indices if provided
-        if attention_mask is not None:
-            loss_mask = masked_indices & attention_mask.bool()
-        else:
-            loss_mask = masked_indices
-
-        if loss_mask.sum() > 0:
-            valid_indices = torch.where(loss_mask)
+        if masked_indices.sum() > 0:
+            valid_indices = torch.where(masked_indices)
             batch_indices, seq_indices = valid_indices

             # Extract the relevant data
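This simplification is sound because forward_process already intersects masked_indices with the attention mask, so the separate loss_mask was redundant. From here the loss is a cross-entropy over only the masked positions, optionally reweighted by 1/p_mask. A minimal sketch of that computation (one common normalization; the trainer's exact reduction may differ):

import torch
import torch.nn.functional as F

def masked_diffusion_loss(logits, labels, masked_indices, p_mask):
    # Gather just the positions the forward process masked.
    batch_idx, seq_idx = torch.where(masked_indices)
    ce = F.cross_entropy(
        logits[batch_idx, seq_idx], labels[batch_idx, seq_idx], reduction="none"
    )
    weights = 1.0 / p_mask[batch_idx, seq_idx]  # importance weights
    return (ce * weights).mean()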
@@ -200,29 +211,23 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             "loss": loss.item(),
             "accuracy": accuracy.item(),
             "mask_ratio": masked_indices.float().mean().item(),
-            "num_masked_tokens": loss_mask.sum().item(),
-            "avg_p_mask": (
-                p_mask[masked_indices].mean().item()
-                if masked_indices.sum() > 0
-                else 0.0
-            ),
-            "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
+            "num_masked_tokens": masked_indices.sum().item(),
+            "avg_p_mask": p_mask[masked_indices].mean().item(),
+            "ce_loss": ce_loss.item(),
         }

         if self.config.importance_weighting:
-            metrics["importance_weight_avg"] = (
-                (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
-            )
+            metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()

         return loss, metrics

     def compute_loss(
         self,
-        model: PreTrainedModel,
-        inputs: Dict[str, torch.Tensor],
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor],
         return_outputs: bool = False,
-        num_items_in_batch: Optional[int] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+        num_items_in_batch: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
         """Override compute_loss to use diffusion loss."""
         input_ids = inputs.get("input_ids")
         attention_mask = inputs.get("attention_mask")
@@ -232,10 +237,10 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
         loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)

-        # Log metrics
-        if self.state.is_local_process_zero:
-            for key, value in metrics.items():
-                self.log({f"train/diffusion_{key}": value})
+        # # Log metrics
+        # if self.state.is_local_process_zero:
+        #     for key, value in metrics.items():
+        #         self.log({f"train/diffusion_{key}": value})

         if return_outputs:
             # TODO: compute outputs (?)
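The per-step logging loop above is disabled rather than removed; one lighter-weight alternative would be a single log call on the trainer's logging cadence, roughly as follows (a sketch against standard Trainer attributes, not what the commit does):

# Inside compute_loss, after metrics are computed (sketch only):
if (
    self.state.is_local_process_zero
    and self.state.global_step % max(int(self.args.logging_steps), 1) == 0
):
    self.log({f"train/diffusion_{key}": value for key, value in metrics.items()})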