fixes + improvements
@@ -22,8 +22,8 @@ importance_weighting: true
output_dir: ./outputs/model-out

sequence_len: 512
sample_packing: true
eval_sample_packing: true
sample_packing: false
eval_sample_packing: false

gradient_accumulation_steps: 8
micro_batch_size: 4
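With sample packing turned off, the per-step token budget follows directly from these values. A rough back-of-the-envelope check (assuming a single device; world size would multiply further, and padding means real token counts are lower):

# Rough token-budget check for the config above (single-device assumption).
micro_batch_size = 4
gradient_accumulation_steps = 8
sequence_len = 512

effective_batch_size = micro_batch_size * gradient_accumulation_steps  # 32 sequences per optimizer step
tokens_per_step = effective_batch_size * sequence_len                   # 16,384 tokens per step, upper bound without packing

print(effective_batch_size, tokens_per_step)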
@@ -51,8 +51,8 @@ eval_steps: 1000
special_tokens:
  pad_token: "<|end_of_text|>"

wandb_project:
wandb_entity:
wandb_project: diffusion-plugin
wandb_entity: axolotl-ai
wandb_watch:
wandb_name:
wandb_log_model:
@@ -41,3 +41,7 @@ class DiffusionArgs(BaseModel):
        default=True,
        description="Apply importance weighting to loss based on masking probability",
    )
    mask_token_id: int = Field(
        default=128002,
        description="Token ID to use for masking. Default is 128002 (<|reserved_special_token_0|> for Llama 3.2)",
    )
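The default of 128002 is the first reserved special token in the Llama 3 family's vocabulary. A quick sanity check along these lines can confirm the mapping for whichever tokenizer a config actually loads (the model id below is only an example, not something fixed by this commit):

from transformers import AutoTokenizer

# Example check -- substitute the tokenizer your config actually loads.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

mask_token_id = 128002
print(tokenizer.convert_ids_to_tokens(mask_token_id))  # expected: "<|reserved_special_token_0|>"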
@@ -1,10 +1,8 @@
"""Custom trainer for diffusion LM training."""

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from torch import nn

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
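The import changes drop `typing.Optional`/`Tuple`/`Dict`/`Union` in favor of builtin generics and union syntax (PEP 585 / PEP 604, i.e. Python 3.9+/3.10+). A minimal sketch of the equivalence, not part of the diff itself:

from typing import Dict, Optional, Tuple

import torch

# Older typing-module style (removed by this commit):
def old_style(x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, float]]:
    return x, {}

# Builtin-generics style used after the change (requires Python >= 3.10 for `X | None`):
def new_style(x: torch.Tensor, mask: torch.Tensor | None = None) -> tuple[torch.Tensor, dict[str, float]]:
    return x, {}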
@@ -19,17 +17,37 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = None
        self._special_token_ids = None

    def set_config(self, config: DictDefault):
        """Set config for diffusion training."""
        self.config = config
        self._cache_special_token_ids()

    def _cache_special_token_ids(self):
        """Cache special token IDs to avoid repeated tokenizer access."""
        if self.processing_class is None:
            self._special_token_ids = set()
            return

        tokenizer = self.processing_class
        special_tokens = set()

        if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
            special_tokens.add(tokenizer.bos_token_id)
        if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
            special_tokens.add(tokenizer.eos_token_id)
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
            special_tokens.add(tokenizer.pad_token_id)

        self._special_token_ids = special_tokens

    def forward_process(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask: torch.Tensor | None = None,
        eps: float = 1e-3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward noising process. A timestep is sampled along the process, and tokens are
        masked with probability determined by the configured noise schedule.
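The next hunk continues the body of `forward_process`; the noise schedule itself is not visible in this diff. As a rough, standalone sketch of the linear schedule commonly used for masked-diffusion LM training (an assumption, not the plugin's exact code): sample a timestep per sequence, map it to a masking probability in [eps, 1], and mask each token with that probability.

import torch

def linear_mask_schedule(batch_size: int, seq_len: int, eps: float = 1e-3, device="cpu"):
    """Toy sketch: per-sequence timestep -> per-token masking probability."""
    t = torch.rand(batch_size, device=device)               # timestep ~ U(0, 1)
    p = (1.0 - eps) * t + eps                                # masking prob in [eps, 1]
    p_mask = p.unsqueeze(1).expand(batch_size, seq_len)      # broadcast over positions
    masked_indices = torch.rand(batch_size, seq_len, device=device) < p_mask
    return p_mask, masked_indices

p_mask, masked = linear_mask_schedule(2, 8)
print(p_mask[:, 0], masked.float().mean())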
@@ -59,19 +77,20 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
            valid_mask = attention_mask.bool()
            p_mask = p_mask * valid_mask.float()

        # Create random mask based on probability
        # Create mask to exclude special tokens (BOS, EOS, PAD) using cached IDs
        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        if self._special_token_ids:
            for token_id in self._special_token_ids:
                special_token_mask |= input_ids == token_id

        # Create random mask based on probability, excluding special tokens
        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
        masked_indices = masked_indices & ~special_token_mask
        if attention_mask is not None:
            masked_indices = masked_indices & attention_mask.bool()

        # Get tokenizer
        tokenizer = self.processing_class
        assert tokenizer is not None, "Tokenizer not available on Trainer object."

        # Get mask token ID
        mask_token_id = getattr(tokenizer, "mask_token_id", None)
        if mask_token_id is None:
            mask_token_id = getattr(tokenizer, "unk_token_id", None)
        # Get mask token ID from config
        mask_token_id = self.config.mask_token_id

        # Create masked input using configured mask token
        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
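A small illustration of the masking step above, with toy values rather than the trainer's real tensors: special-token positions are removed from the candidate mask before `torch.where` swaps the remaining positions for the configured mask token.

import torch

input_ids = torch.tensor([[128000, 15, 27, 42, 128001]])    # BOS, three content tokens, EOS
special_ids = {128000, 128001}
mask_token_id = 128002

special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
for tid in special_ids:
    special_mask |= input_ids == tid

candidate = torch.tensor([[True, True, False, True, True]])  # pretend rand < p_mask
masked_indices = candidate & ~special_mask                    # BOS/EOS are never masked

noisy = torch.where(masked_indices, mask_token_id, input_ids)
print(noisy)  # tensor([[128000, 128002, 27, 128002, 128001]])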
@@ -79,49 +98,47 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
        return noisy_batch, masked_indices, p_mask

    def create_bidirectional_attention_mask(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Create bidirectional attention mask to override default causal masking.
        Handles sample-packed sequences where different samples are identified
        by different attention mask values.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len]

        Returns:
            bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Create bidirectional attention mask to override default causal masking
        # Shape: [batch_size, 1, seq_len, seq_len]
        bidirectional_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=input_ids.device
        )
        bidirectional_mask = (
            bidirectional_mask.unsqueeze(0)
            .unsqueeze(0)
            .expand(batch_size, 1, seq_len, seq_len)
        )

        # Apply padding mask if provided
        if attention_mask is not None:
            # Convert attention_mask to 4D and apply
            expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)

            bidirectional_mask = (
                bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
        if attention_mask is None or not self.config.sample_packing:
            # Simple case: no attention mask, allow all-to-all attention
            return torch.ones(
                batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
            )

        # Create attention mask by comparing sample IDs element-wise
        mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
        mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]

        # Tokens can attend to each other if they have the same non-zero sample ID
        bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)

        # Add head dimension: [batch_size, 1, seq_len, seq_len]
        bidirectional_mask = bidirectional_mask.unsqueeze(1)

        return bidirectional_mask

    def compute_diffusion_loss(
        self,
        model: PreTrainedModel,
        model: nn.Module,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """
        Compute diffusion loss.
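The sample-ID comparison above yields a block-diagonal mask for packed sequences. A tiny worked example (values chosen only for illustration): with packed sample IDs `[1, 1, 2, 2, 0]`, tokens attend only within their own sample, and the padding position (ID 0) attends to nothing.

import torch

attention_mask = torch.tensor([[1, 1, 2, 2, 0]])  # two packed samples plus one pad position

mask_i = attention_mask.unsqueeze(2)  # [batch, seq, 1]
mask_j = attention_mask.unsqueeze(1)  # [batch, 1, seq]
bidirectional = (mask_i == mask_j) & (mask_i > 0)
bidirectional = bidirectional.unsqueeze(1)  # [batch, 1, seq, seq]

print(bidirectional.int()[0, 0])
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]])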
@@ -139,7 +156,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
            input_ids, attention_mask, self.config.eps
        )

        # Create bidirectional attention mask (always required for diffusion training)
        # Create bidirectional attention mask
        bidirectional_mask = self.create_bidirectional_attention_mask(
            input_ids, attention_mask
        )
@@ -151,14 +168,8 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
        )
        logits = outputs.logits

        # Apply attention mask to masked_indices if provided
        if attention_mask is not None:
            loss_mask = masked_indices & attention_mask.bool()
        else:
            loss_mask = masked_indices

        if loss_mask.sum() > 0:
            valid_indices = torch.where(loss_mask)
        if masked_indices.sum() > 0:
            valid_indices = torch.where(masked_indices)
            batch_indices, seq_indices = valid_indices

            # Extract the relevant data
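For context on the indexing in this hunk (a hedged sketch, since the rest of the loss code is not shown in the diff): the masked positions are gathered from logits and labels, cross-entropy is computed per token, and, when importance weighting is enabled, each token's loss is scaled by 1 / p_mask. The normalization below is one common formulation, not necessarily the plugin's exact one.

import torch
import torch.nn.functional as F

def sketch_masked_ce(logits, input_ids, masked_indices, p_mask, importance_weighting=True):
    """Hedged sketch: token-level CE on masked positions, optionally weighted by 1/p_mask."""
    batch_indices, seq_indices = torch.where(masked_indices)
    masked_logits = logits[batch_indices, seq_indices]      # [num_masked, vocab]
    masked_labels = input_ids[batch_indices, seq_indices]   # [num_masked]
    masked_p_mask = p_mask[batch_indices, seq_indices]      # [num_masked]

    ce = F.cross_entropy(masked_logits, masked_labels, reduction="none")
    if importance_weighting:
        ce = ce / masked_p_mask                              # rarely-masked tokens count more
    return ce.mean()                                         # one possible normalization

# Toy shapes only
logits = torch.randn(2, 4, 10)
input_ids = torch.randint(0, 10, (2, 4))
masked = torch.tensor([[True, False, True, False], [False, True, False, False]])
p_mask = torch.full((2, 4), 0.5)
print(sketch_masked_ce(logits, input_ids, masked, p_mask))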
@@ -200,29 +211,23 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
            "loss": loss.item(),
            "accuracy": accuracy.item(),
            "mask_ratio": masked_indices.float().mean().item(),
            "num_masked_tokens": loss_mask.sum().item(),
            "avg_p_mask": (
                p_mask[masked_indices].mean().item()
                if masked_indices.sum() > 0
                else 0.0
            ),
            "ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
            "num_masked_tokens": masked_indices.sum().item(),
            "avg_p_mask": p_mask[masked_indices].mean().item(),
            "ce_loss": ce_loss.item(),
        }

        if self.config.importance_weighting:
            metrics["importance_weight_avg"] = (
                (1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
            )
            metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()

        return loss, metrics

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        model: nn.Module,
        inputs: dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Override compute_loss to use diffusion loss."""
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
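The `1 / masked_p_mask` factor in the metrics above is the same importance weight applied in the loss. The intuition (a quick numeric check, not taken from the diff): weighting each masked token by the inverse of its masking probability makes the masked-only sum an unbiased estimate of the full per-token sum.

import torch

torch.manual_seed(0)

per_token_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
p_mask = torch.tensor([0.2, 0.5, 0.8, 0.4])

estimates = []
for _ in range(20000):
    masked = torch.rand(4) < p_mask
    # importance-weighted sum over masked tokens only
    estimates.append((per_token_loss / p_mask)[masked].sum())

print(torch.stack(estimates).mean())  # ~= per_token_loss.sum() == 10.0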
@@ -232,10 +237,10 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors

        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)

        # Log metrics
        if self.state.is_local_process_zero:
            for key, value in metrics.items():
                self.log({f"train/diffusion_{key}": value})
        # # Log metrics
        # if self.state.is_local_process_zero:
        #     for key, value in metrics.items():
        #         self.log({f"train/diffusion_{key}": value})

        if return_outputs:
            # TODO: compute outputs (?)
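For completeness, a rough sketch of how a trainer like this gets wired up (hypothetical names and values; the plugin's real construction path goes through Axolotl's trainer builder): the diffusion config is attached after construction via `set_config`, which also triggers the special-token caching shown earlier.

# Hypothetical wiring, for illustration only -- not the plugin's actual builder code.
from axolotl.utils.dict import DictDefault

diffusion_cfg = DictDefault(
    {
        "eps": 1e-3,
        "importance_weighting": True,
        "mask_token_id": 128002,
        "sample_packing": False,
    }
)

# trainer = DiffusionTrainer(model=model, args=training_args, processing_class=tokenizer, ...)
# trainer.set_config(diffusion_cfg)
# trainer.train()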