fixes + improvements
This commit is contained in:
@@ -22,8 +22,8 @@ importance_weighting: true
|
|||||||
output_dir: ./outputs/model-out
|
output_dir: ./outputs/model-out
|
||||||
|
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
sample_packing: true
|
sample_packing: false
|
||||||
eval_sample_packing: true
|
eval_sample_packing: false
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
@@ -51,8 +51,8 @@ eval_steps: 1000
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
wandb_project:
|
wandb_project: diffusion-plugin
|
||||||
wandb_entity:
|
wandb_entity: axolotl-ai
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -41,3 +41,7 @@ class DiffusionArgs(BaseModel):
|
|||||||
default=True,
|
default=True,
|
||||||
description="Apply importance weighting to loss based on masking probability",
|
description="Apply importance weighting to loss based on masking probability",
|
||||||
)
|
)
|
||||||
|
mask_token_id: int = Field(
|
||||||
|
default=128002,
|
||||||
|
description="Token ID to use for masking. Default is 128002 (<|reserved_special_token_0|> for Llama 3.2)",
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
"""Custom trainer for diffusion LM training."""
|
"""Custom trainer for diffusion LM training."""
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PreTrainedModel
|
from torch import nn
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -19,17 +17,37 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self._special_token_ids = None
|
||||||
|
|
||||||
def set_config(self, config: DictDefault):
|
def set_config(self, config: DictDefault):
|
||||||
"""Set config for diffusion training."""
|
"""Set config for diffusion training."""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self._cache_special_token_ids()
|
||||||
|
|
||||||
|
def _cache_special_token_ids(self):
|
||||||
|
"""Cache special token IDs to avoid repeated tokenizer access."""
|
||||||
|
if self.processing_class is None:
|
||||||
|
self._special_token_ids = set()
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = self.processing_class
|
||||||
|
special_tokens = set()
|
||||||
|
|
||||||
|
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
|
||||||
|
special_tokens.add(tokenizer.bos_token_id)
|
||||||
|
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
|
||||||
|
special_tokens.add(tokenizer.eos_token_id)
|
||||||
|
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
|
||||||
|
special_tokens.add(tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
self._special_token_ids = special_tokens
|
||||||
|
|
||||||
def forward_process(
|
def forward_process(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
eps: float = 1e-3,
|
eps: float = 1e-3,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Forward noising process. A timestep is sampled along the process, and tokens are
|
Forward noising process. A timestep is sampled along the process, and tokens are
|
||||||
masked with probability determined by the configured noise schedule.
|
masked with probability determined by the configured noise schedule.
|
||||||
@@ -59,19 +77,20 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
valid_mask = attention_mask.bool()
|
valid_mask = attention_mask.bool()
|
||||||
p_mask = p_mask * valid_mask.float()
|
p_mask = p_mask * valid_mask.float()
|
||||||
|
|
||||||
# Create random mask based on probability
|
# Create mask to exclude special tokens (BOS, EOS, PAD) using cached IDs
|
||||||
|
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
||||||
|
if self._special_token_ids:
|
||||||
|
for token_id in self._special_token_ids:
|
||||||
|
special_token_mask |= input_ids == token_id
|
||||||
|
|
||||||
|
# Create random mask based on probability, excluding special tokens
|
||||||
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
|
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
|
||||||
|
masked_indices = masked_indices & ~special_token_mask
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
masked_indices = masked_indices & attention_mask.bool()
|
masked_indices = masked_indices & attention_mask.bool()
|
||||||
|
|
||||||
# Get tokenizer
|
# Get mask token ID from config
|
||||||
tokenizer = self.processing_class
|
mask_token_id = self.config.mask_token_id
|
||||||
assert tokenizer is not None, "Tokenizer not available on Trainer object."
|
|
||||||
|
|
||||||
# Get mask token ID
|
|
||||||
mask_token_id = getattr(tokenizer, "mask_token_id", None)
|
|
||||||
if mask_token_id is None:
|
|
||||||
mask_token_id = getattr(tokenizer, "unk_token_id", None)
|
|
||||||
|
|
||||||
# Create masked input using configured mask token
|
# Create masked input using configured mask token
|
||||||
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
|
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
|
||||||
@@ -79,49 +98,47 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
return noisy_batch, masked_indices, p_mask
|
return noisy_batch, masked_indices, p_mask
|
||||||
|
|
||||||
def create_bidirectional_attention_mask(
|
def create_bidirectional_attention_mask(
|
||||||
self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Create bidirectional attention mask to override default causal masking.
|
Create bidirectional attention mask to override default causal masking.
|
||||||
|
Handles sample-packed sequences where different samples are identified
|
||||||
|
by different attention mask values.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids: Input token ids [batch_size, seq_len].
|
input_ids: Input token ids [batch_size, seq_len].
|
||||||
attention_mask: Attention mask [batch_size, seq_len].
|
attention_mask: Attention mask [batch_size, seq_len]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
|
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
|
||||||
"""
|
"""
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
# Create bidirectional attention mask to override default causal masking
|
if attention_mask is None or not self.config.sample_packing:
|
||||||
# Shape: [batch_size, 1, seq_len, seq_len]
|
# Simple case: no attention mask, allow all-to-all attention
|
||||||
bidirectional_mask = torch.ones(
|
return torch.ones(
|
||||||
seq_len, seq_len, dtype=torch.bool, device=input_ids.device
|
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
|
||||||
)
|
|
||||||
bidirectional_mask = (
|
|
||||||
bidirectional_mask.unsqueeze(0)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(batch_size, 1, seq_len, seq_len)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply padding mask if provided
|
|
||||||
if attention_mask is not None:
|
|
||||||
# Convert attention_mask to 4D and apply
|
|
||||||
expanded_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
|
|
||||||
expanded_mask = expanded_mask.expand(batch_size, 1, seq_len, seq_len)
|
|
||||||
|
|
||||||
bidirectional_mask = (
|
|
||||||
bidirectional_mask & expanded_mask & expanded_mask.transpose(-1, -2)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create attention mask by comparing sample IDs element-wise
|
||||||
|
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
|
||||||
|
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
|
||||||
|
|
||||||
|
# Tokens can attend to each other if they have the same non-zero sample ID
|
||||||
|
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
|
||||||
|
|
||||||
|
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||||
|
bidirectional_mask = bidirectional_mask.unsqueeze(1)
|
||||||
|
|
||||||
return bidirectional_mask
|
return bidirectional_mask
|
||||||
|
|
||||||
def compute_diffusion_loss(
|
def compute_diffusion_loss(
|
||||||
self,
|
self,
|
||||||
model: PreTrainedModel,
|
model: nn.Module,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> Tuple[torch.Tensor, Dict[str, float]]:
|
) -> tuple[torch.Tensor, dict[str, float]]:
|
||||||
"""
|
"""
|
||||||
Compute diffusion loss.
|
Compute diffusion loss.
|
||||||
|
|
||||||
@@ -139,7 +156,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
input_ids, attention_mask, self.config.eps
|
input_ids, attention_mask, self.config.eps
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create bidirectional attention mask (always required for diffusion training)
|
# Create bidirectional attention mask
|
||||||
bidirectional_mask = self.create_bidirectional_attention_mask(
|
bidirectional_mask = self.create_bidirectional_attention_mask(
|
||||||
input_ids, attention_mask
|
input_ids, attention_mask
|
||||||
)
|
)
|
||||||
@@ -151,14 +168,8 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
)
|
)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# Apply attention mask to masked_indices if provided
|
if masked_indices.sum() > 0:
|
||||||
if attention_mask is not None:
|
valid_indices = torch.where(masked_indices)
|
||||||
loss_mask = masked_indices & attention_mask.bool()
|
|
||||||
else:
|
|
||||||
loss_mask = masked_indices
|
|
||||||
|
|
||||||
if loss_mask.sum() > 0:
|
|
||||||
valid_indices = torch.where(loss_mask)
|
|
||||||
batch_indices, seq_indices = valid_indices
|
batch_indices, seq_indices = valid_indices
|
||||||
|
|
||||||
# Extract the relevant data
|
# Extract the relevant data
|
||||||
@@ -200,29 +211,23 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"accuracy": accuracy.item(),
|
"accuracy": accuracy.item(),
|
||||||
"mask_ratio": masked_indices.float().mean().item(),
|
"mask_ratio": masked_indices.float().mean().item(),
|
||||||
"num_masked_tokens": loss_mask.sum().item(),
|
"num_masked_tokens": masked_indices.sum().item(),
|
||||||
"avg_p_mask": (
|
"avg_p_mask": p_mask[masked_indices].mean().item(),
|
||||||
p_mask[masked_indices].mean().item()
|
"ce_loss": ce_loss.item(),
|
||||||
if masked_indices.sum() > 0
|
|
||||||
else 0.0
|
|
||||||
),
|
|
||||||
"ce_loss": ce_loss.item() if loss_mask.sum() > 0 else 0.0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.config.importance_weighting:
|
if self.config.importance_weighting:
|
||||||
metrics["importance_weight_avg"] = (
|
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||||
(1.0 / masked_p_mask).mean().item() if loss_mask.sum() > 0 else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self,
|
self,
|
||||||
model: PreTrainedModel,
|
model: nn.Module,
|
||||||
inputs: Dict[str, torch.Tensor],
|
inputs: dict[str, torch.Tensor],
|
||||||
return_outputs: bool = False,
|
return_outputs: bool = False,
|
||||||
num_items_in_batch: Optional[int] = None,
|
num_items_in_batch: torch.Tensor | None = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
|
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||||
"""Override compute_loss to use diffusion loss."""
|
"""Override compute_loss to use diffusion loss."""
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask")
|
attention_mask = inputs.get("attention_mask")
|
||||||
@@ -232,10 +237,10 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
|
|||||||
|
|
||||||
loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
|
loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
|
||||||
|
|
||||||
# Log metrics
|
# # Log metrics
|
||||||
if self.state.is_local_process_zero:
|
# if self.state.is_local_process_zero:
|
||||||
for key, value in metrics.items():
|
# for key, value in metrics.items():
|
||||||
self.log({f"train/diffusion_{key}": value})
|
# self.log({f"train/diffusion_{key}": value})
|
||||||
|
|
||||||
if return_outputs:
|
if return_outputs:
|
||||||
# TODO: compute outputs (?)
|
# TODO: compute outputs (?)
|
||||||
|
|||||||
Reference in New Issue
Block a user