sample generation, tests fixes

This commit is contained in:
Dan Saunders
2025-08-18 18:25:04 +00:00
parent 8569675b26
commit 556a69118f
9 changed files with 585 additions and 171 deletions

View File

@@ -147,7 +147,7 @@ class BasePlugin:
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Returns a custom class for the trainer. """Returns a custom class for the trainer.
Args: Args:

View File

@@ -27,15 +27,24 @@ Add the following to your Axolotl configuration YAML:
```yaml ```yaml
# Enable diffusion LM training plugin # Enable diffusion LM training plugin
plugins: plugins:
- diffusion.DiffusionPlugin - axolotl.integrations.diffusion.DiffusionPlugin
# Diffusion-specific configuration # Diffusion-specific configuration
noise_schedule: "linear" # or "cosine" noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1 min_mask_ratio: 0.1
max_mask_ratio: 0.9 max_mask_ratio: 0.9
num_diffusion_steps: 128 num_diffusion_steps: 128
eps: 1e-3 eps: 1e-3
importance_weighting: true importance_weighting: true
mask_token_id: 128002
# Sample generation (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
# Model configuration # Model configuration
base_model: meta-llama/Llama-3.2-1B base_model: meta-llama/Llama-3.2-1B
@@ -88,24 +97,37 @@ loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
- Consider using gradient checkpointing, torch.compile, - Consider using gradient checkpointing, torch.compile,
### Training Stability ### Training Stability
- Start with `noise_schedule: "linear"` for more predictable behavior - Start with `noise_schedule: linear` for more predictable behavior
- Enable `importance_weighting` for better gradient scaling - Enable `importance_weighting: true` for better gradient scaling
### Convergence ### Convergence
- Monitor the `diffusion_loss` and `diffusion_accuracy` metrics - Monitor the `diffusion_loss` and `diffusion_accuracy` metrics
- Expect different loss curves compared to standard language modeling - Expect different loss curves compared to standard language modeling
## Sample Generation
When `generate_samples: true`, the plugin generates samples from the evaluation dataset every `generation_interval` training steps (an evaluation dataset must be configured):
```
📝 Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
## Metrics and Monitoring ## Metrics and Monitoring
The plugin adds several metrics to track diffusion training: The plugin adds several metrics to track diffusion training:
- `train/diffusion_loss`: Weighted diffusion loss - `train/loss`: Weighted diffusion loss
- `train/diffusion_accuracy`: Accuracy on masked tokens - `train/accuracy`: Accuracy on masked tokens
- `train/diffusion_mask_ratio`: Average fraction of tokens masked - `train/mask_ratio`: Average fraction of tokens masked
- `train/diffusion_num_masked_tokens`: Number of tokens masked - `train/num_masked_tokens`: Number of tokens masked
- `train/diffusion_avg_p_mask`: Average masking probability - `train/avg_p_mask`: Average masking probability
- `train/diffusion_ce_loss`: Unweighted cross-entropy loss - `train/ce_loss`: Unweighted cross-entropy loss
- `train/diffusion_importance_weight_avg`: Average importance weight - `train/importance_weight_avg`: Average importance weight
## Limitations ## Limitations

View File

@@ -46,5 +46,27 @@ class DiffusionArgs(BaseModel):
description=( description=(
"Token ID to use for masking. Default is 128002 " "Token ID to use for masking. Default is 128002 "
"(<|reserved_special_token_0|> for Llama 3.2)" "(<|reserved_special_token_0|> for Llama 3.2)"
) ),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
) )

View File

@@ -0,0 +1,115 @@
"""Callbacks for diffusion training."""
import wandb
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.logging import get_logger
from .generation import generate_samples
LOG = get_logger(__name__)
class DiffusionGenerationCallback(TrainerCallback):
    """Callback for generating samples during diffusion training.

    Every ``generation_interval`` optimizer steps the callback draws a few
    sequences from the trainer's evaluation dataloader, masks and re-generates
    them via ``generate_samples``, and logs the results to the console and,
    when enabled, to Weights & Biases.
    """

    def __init__(self, trainer):
        """Store the owning trainer.

        Args:
            trainer: Trainer whose ``config``, ``model``, evaluation dataset and
                ``state`` are read each time samples are generated.
        """
        self.trainer = trainer

    # pylint: disable=unused-argument
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Generate samples at specified intervals.

        Does nothing at step 0, on steps that are not a multiple of
        ``config.generation_interval``, or when the trainer has no evaluation
        dataset.
        """
        config = self.trainer.config
        step = state.global_step

        # Guard clauses: only run on interval boundaries after step 0 ...
        if step <= 0 or step % config.generation_interval != 0:
            return
        # ... and only when there is evaluation data to sample from.
        if getattr(self.trainer, "eval_dataset", None) is None:
            return

        LOG.info(
            f"Generating {config.num_generation_samples} samples at step {step}..."
        )

        # Newer transformers releases expose the tokenizer as
        # `processing_class` and deprecate `Trainer.tokenizer`; prefer the new
        # attribute and fall back to the old one when it is absent.
        tokenizer = getattr(self.trainer, "processing_class", None)
        if tokenizer is None:
            tokenizer = self.trainer.tokenizer

        samples = generate_samples(
            model=self.trainer.model,
            tokenizer=tokenizer,
            val_dataloader=self.trainer.get_eval_dataloader(),
            num_generation_samples=config.num_generation_samples,
            max_length=config.generation_max_length,
            num_diffusion_steps=config.generation_steps,
            temperature=config.generation_temperature,
            mask_token_id=config.mask_token_id,
        )

        self._log_samples(samples, step)

    def _log_samples(self, samples: list, step: int):
        """Log generated samples.

        Args:
            samples: Dictionaries as returned by ``generate_samples``; the keys
                ``original``, ``masked``, ``generated``, ``mask_ratio``,
                ``masked_tokens`` and ``total_tokens`` are read.
            step: Global training step the samples were generated at.
        """
        if not samples:
            return

        separator = "=" * 60
        LOG.info(separator)
        LOG.info("GENERATED SAMPLES")
        LOG.info(separator)

        for i, sample_data in enumerate(samples, 1):
            total_tokens = sample_data["total_tokens"]
            LOG.info(f"\nSample {i}:")
            LOG.info(f"\tOriginal ({total_tokens} tokens): {sample_data['original']}")
            LOG.info(
                f"\tMasked ({sample_data['masked_tokens']}/{total_tokens} tokens, "
                f"{sample_data['mask_ratio']:.1%}): {sample_data['masked']}"
            )
            LOG.info(f"\tGenerated: {sample_data['generated']}")

        LOG.info(separator)

        # Only the main process talks to wandb, and only if a run is active.
        if not (
            self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero
        ):
            return
        if wandb.run is None:
            return

        columns = [
            "step",
            "original",
            "masked",
            "generated",
            "mask_ratio",
            "masked_tokens",
            "total_tokens",
        ]
        rows = [
            [
                step,
                sample["original"],
                sample["masked"],
                sample["generated"],
                f"{sample['mask_ratio']:.1%}",
                sample["masked_tokens"],
                sample["total_tokens"],
            ]
            for sample in samples
        ]

        # Do not force `step=`: wandb requires its step to increase
        # monotonically and ignores data logged to a step lower than its
        # current internal counter, which other loggers in the same run may
        # already have advanced past `global_step`. The training step is
        # recorded as a regular key instead (and in the table's first column).
        wandb.log(
            {
                "generated_samples": wandb.Table(columns=columns, data=rows),
                "train/global_step": step,
            }
        )

View File

@@ -0,0 +1,267 @@
"""Sample generation utilities for diffusion training."""
import logging
from typing import Any, List, Optional
import torch
logger = logging.getLogger(__name__)
def generate_samples(
    model: torch.nn.Module,
    tokenizer: Any,
    val_dataloader: Optional[Any] = None,
    num_generation_samples: int = 3,
    max_length: int = 100,
    num_diffusion_steps: int = 128,
    temperature: float = 0.0,
    mask_token_id: int = 32000,
) -> List[dict]:
    """
    Generate text samples using the diffusion model by randomly masking sequences
    from the validation dataset and running the reverse diffusion process.

    The model is switched to eval mode for the duration of the call and put
    back into the mode it was in on entry, even if generation raises.

    Args:
        model: The wrapped or unwrapped model
        tokenizer: Tokenizer for encoding/decoding
        val_dataloader: Validation dataloader (for sampling sequences)
        num_generation_samples: Number of samples to generate
        max_length: Maximum length of sequences to use
        num_diffusion_steps: Number of diffusion steps for generation
        temperature: Temperature for sampling (0.0 = deterministic)
        mask_token_id: Token ID used for masking

    Returns:
        List of dictionaries with original text, masked text, and generated text
        (empty when no dataloader is given or no sequence could be sampled)
    """
    if val_dataloader is None:
        logger.warning("No validation dataloader provided, cannot generate samples")
        return []

    # Get the actual model (unwrap if needed)
    unwrapped_model = model.module if hasattr(model, "module") else model

    # Remember the current mode so the caller's state is restored afterwards
    # instead of unconditionally forcing train mode.
    was_training = unwrapped_model.training
    unwrapped_model.eval()

    generations: List[dict] = []
    try:
        # Sample sequences from validation dataset
        sampled_sequences = _sample_sequences_from_dataloader(
            val_dataloader, num_generation_samples, max_length, unwrapped_model.device
        )
        logger.info(
            "Sampled %d sequences from validation dataset", len(sampled_sequences)
        )

        # Generate samples using reverse diffusion process
        with torch.no_grad():
            for original_sequence in sampled_sequences:
                generations.append(
                    _generate(
                        unwrapped_model,
                        tokenizer,
                        original_sequence,
                        num_diffusion_steps,
                        temperature,
                        mask_token_id,
                    )
                )
    finally:
        # Runs on success and on error so a failed generation cannot leave the
        # model stuck in eval mode.
        if was_training:
            unwrapped_model.train()

    return generations
def _sample_sequences_from_dataloader(
val_dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]:
"""Sample sequences from validation dataloader."""
sampled_sequences = []
sample_count = 0
# Add randomness by skipping a random number of batches
skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0
for batch in val_dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
# Limit sequence length to max_length
actual_length = min(seq_len, max_length)
if actual_length < 10: # Skip very short sequences
continue
# Extract the sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def _generate(
    model: torch.nn.Module,
    tokenizer: Any,
    original_sequence: torch.Tensor,
    num_diffusion_steps: int,
    temperature: float,
    mask_token_id: int,
    min_mask_ratio: float = 0.1,
    max_mask_ratio: float = 0.7,
) -> dict:
    """Generate a single sample using reverse diffusion.

    A random fraction of the tokens (drawn uniformly between ``min_mask_ratio``
    and ``max_mask_ratio``) is replaced by ``mask_token_id`` and the model then
    fills the masked positions back in over ``num_diffusion_steps`` steps.

    Args:
        model: Model passed through to ``_diffusion_step``.
        tokenizer: Tokenizer used to decode token ids for display.
        original_sequence: Token ids of shape (1, seq_len).
        num_diffusion_steps: Number of reverse diffusion steps to run.
        temperature: Sampling temperature (0.0 = deterministic).
        mask_token_id: Token ID used for masking.
        min_mask_ratio: Lower bound of the random mask ratio.
        max_mask_ratio: Upper bound of the random mask ratio.

    Returns:
        Dictionary with the keys ``original``, ``masked``, ``generated``,
        ``mask_ratio``, ``masked_tokens``, ``total_tokens`` and ``formatted``.
    """
    device = original_sequence.device

    # Get original text for comparison
    original_text = tokenizer.decode(
        original_sequence[0].cpu(), skip_special_tokens=True
    )

    # Pick how many tokens to mask
    total_tokens = original_sequence.size(1)
    target_mask_ratio = (
        torch.rand(1).item() * (max_mask_ratio - min_mask_ratio) + min_mask_ratio
    )
    # int() truncates, which could leave nothing to generate for short
    # sequences; always mask at least one token (but never more than exist).
    target_masked_tokens = min(
        total_tokens, max(1, int(total_tokens * target_mask_ratio))
    )

    # Create random mask indices directly on the sequence's device
    mask_positions = torch.randperm(total_tokens, device=device)[
        :target_masked_tokens
    ]
    masked_indices = torch.zeros(1, total_tokens, dtype=torch.bool, device=device)
    masked_indices[0, mask_positions] = True

    # Create masked sequence
    masked_sequence = original_sequence.clone()
    masked_sequence[masked_indices] = mask_token_id

    # Calculate actual mask ratio (guarding against an empty sequence)
    masked_tokens = int(masked_indices.sum().item())
    mask_ratio = masked_tokens / total_tokens if total_tokens else 0.0

    # Masked text for comparison, with the mask token shown as [MASK]
    masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
    masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)

    # Run reverse diffusion process on a copy so `masked_sequence` stays intact
    sequence = masked_sequence.clone()
    for step in range(num_diffusion_steps):
        sequence = _diffusion_step(
            model, sequence, step, num_diffusion_steps, temperature, mask_token_id
        )

    # Get final generated text
    generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)

    return {
        "original": original_text,
        "masked": masked_text,
        "generated": generated_text,
        "mask_ratio": mask_ratio,
        "masked_tokens": masked_tokens,
        "total_tokens": total_tokens,
        "formatted": (
            f"Original: '{original_text}' → Masked: '{masked_text}' "
            f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
        ),
    }
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
# Get the mask token representation from the tokenizer
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
# Clean up special tokens and whitespace
cleaned = cleaned.replace("<s>", "").replace("</s>", "").strip()
cleaned = " ".join(cleaned.split())
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create bidirectional attention mask for diffusion
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
# Sample predictions
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
# Deterministic sampling when temperature is 0
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

View File

@@ -1,6 +1,7 @@
"""Diffusion LM training plugin for Axolotl.""" """Diffusion LM training plugin for Axolotl."""
from transformers import PreTrainedModel, Trainer from peft import PeftModel
from transformers import PreTrainedModel
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -27,14 +28,14 @@ class DiffusionPlugin(BasePlugin):
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Perform actions after model is loaded.""" """Perform actions after model is loaded."""
self.cfg = cfg self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
"""Return custom trainer class for diffusion training.""" """Return custom trainer class for diffusion training."""
return DiffusionTrainer return DiffusionTrainer
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
"""Configure trainer after creation.""" """Configure trainer after creation."""
trainer.set_config(cfg) trainer.set_config(cfg)

View File

@@ -10,6 +10,8 @@ from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -18,14 +20,18 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._config = None self.config = None
self._special_token_ids = None self._special_token_ids = None
def set_config(self, config: DictDefault): def set_config(self, config: DictDefault):
"""Set config for diffusion training.""" """Set config for diffusion training."""
self._config = config self.config = config
self._cache_special_token_ids() self._cache_special_token_ids()
if config.generate_samples:
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def compute_loss( def compute_loss(
self, self,
model: nn.Module, model: nn.Module,
@@ -111,19 +117,19 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
for token_id in self._special_token_ids: for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id special_token_mask |= input_ids == token_id
# Create random mask based on p_mask # Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None: if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool() masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens # For SFT data, only mask answer tokens
if labels is not None: if labels is not None:
answer_mask = labels != -100 answer_mask = labels != -100
masked_indices = masked_indices & answer_mask masked_indices = masked_indices & answer_mask
# Create masked input # Create masked input
mask_token_id = self._config.mask_token_id mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask return noisy_batch, masked_indices, p_mask
@@ -147,7 +153,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
device = input_ids.device device = input_ids.device
if attention_mask is None or not self._config.sample_packing: if attention_mask is None or not self.config.sample_packing:
return torch.ones( return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
) )
@@ -186,7 +192,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
""" """
# Apply forward process # Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process( noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self._config.eps input_ids, attention_mask, labels, self.config.eps
) )
# Create bidirectional attention mask # Create bidirectional attention mask
@@ -214,7 +220,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
masked_logits.float(), masked_targets, reduction="none" masked_logits.float(), masked_targets, reduction="none"
) )
if self._config.importance_weighting: if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float() masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask weighted_loss = token_loss / masked_p_mask
else: else:
@@ -222,26 +228,28 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
# Final loss: sum weighted losses, normalize # Final loss: sum weighted losses, normalize
if labels is not None: if labels is not None:
# For SFT data: normalize by answer length per sample as per LLaDA guidelines # For SFT data: normalize by answer length per sample
answer_mask = labels != -100 answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens # Get batch indices for masked tokens
masked_batch_indices = batch_indices masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length # Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(input_ids.shape[0], device=input_ids.device) loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]): for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0: if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum() sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i] loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean() loss = loss_per_sample.mean()
else: else:
# Original normalization for non-SFT data # Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
ce_loss = token_loss.mean() ce_loss = token_loss.mean()
# Compute accuracy on masked tokens # Compute accuracy on masked tokens
@@ -262,14 +270,14 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"avg_p_mask": p_mask[masked_indices].mean().item(), "avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(), "ce_loss": ce_loss.item(),
} }
# Add SFT-specific metrics # Add SFT-specific metrics
if labels is not None: if labels is not None:
answer_mask = labels != -100 answer_mask = labels != -100
metrics["answer_ratio"] = answer_mask.float().mean().item() metrics["answer_ratio"] = answer_mask.float().mean().item()
metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item() metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
if self._config.importance_weighting: if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval" train_eval: Literal["train", "eval"] = "train" if model.training else "eval"

View File

@@ -1,6 +1,4 @@
""" """E2E smoke test for diffusion training plugin."""
E2E smoke test for diffusion training plugin
"""
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
@@ -11,13 +9,12 @@ from tests.e2e.utils import check_model_output_exists
class TestDiffusion: class TestDiffusion:
""" """Test case for diffusion training plugin."""
Test case for diffusion training plugin
"""
def test_diffusion_smoke_test(self, temp_dir): def test_diffusion_smoke_test(self, temp_dir):
""" """
Smoke test for diffusion training to ensure the plugin loads and trains without error. Smoke test for diffusion training to ensure the plugin loads and trains without
error.
""" """
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -36,7 +33,7 @@ class TestDiffusion:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 3, # Very short for smoke test "max_steps": 3,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -48,33 +45,23 @@ class TestDiffusion:
"save_first_step": False, "save_first_step": False,
"logging_steps": 1, "logging_steps": 1,
"eval_steps": 3, "eval_steps": 3,
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
# Diffusion-specific config # Diffusion-specific config
"diffusion_mask_token_id": 32000, "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3, "diffusion_eps": 1e-3,
"diffusion_importance_weighting": False, "diffusion_importance_weighting": False,
} }
) )
# Normalize and validate config
cfg = normalize_config(cfg)
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# Load datasets to ensure they work with diffusion training train(cfg=cfg, dataset_meta=dataset_meta)
datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({})) check_model_output_exists(temp_dir, cfg)
assert datasets_meta.train_dataset is not None
assert len(datasets_meta.train_dataset) > 0
# Run training
train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta)
# Check that model was saved
check_model_output_exists(cfg)
def test_diffusion_sft_labels(self, temp_dir): def test_diffusion_sft_labels(self, temp_dir):
""" """Test that diffusion training properly handles SFT data with labels."""
Test that diffusion training properly handles SFT data with labels.
"""
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -92,7 +79,7 @@ class TestDiffusion:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, # Very short for smoke test "max_steps": 3,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -104,35 +91,29 @@ class TestDiffusion:
"save_first_step": False, "save_first_step": False,
"logging_steps": 1, "logging_steps": 1,
"eval_steps": 2, "eval_steps": 2,
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
# Diffusion-specific config # Diffusion-specific config
"diffusion_mask_token_id": 32000, "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion_mask_token_id": 16,
"diffusion_eps": 1e-3, "diffusion_eps": 1e-3,
"diffusion_importance_weighting": True, # Test importance weighting "diffusion_importance_weighting": True,
# Ensure we have proper SFT labels # Ensure we have proper SFT labels
"train_on_inputs": False, # This ensures prompt tokens get -100 labels "train_on_inputs": False,
} }
) )
# Normalize and validate config
cfg = normalize_config(cfg)
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# Load datasets
datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({}))
# Verify that the dataset has labels # Verify that the dataset has labels
sample = datasets_meta.train_dataset[0] sample = dataset_meta.train_dataset[0]
assert "labels" in sample, "SFT dataset should have labels" assert "labels" in sample, "SFT dataset should have labels"
# Check that some labels are -100 (prompt tokens) # Check that some labels are -100 (prompt tokens)
labels = sample["labels"] labels = sample["labels"]
if hasattr(labels, "tolist"): if hasattr(labels, "tolist"):
labels = labels.tolist() labels = labels.tolist()
assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens" assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
# Run training train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta) check_model_output_exists(temp_dir, cfg)
# Check that model was saved
check_model_output_exists(cfg)

View File

@@ -1,8 +1,11 @@
"""Tests for diffusion trainer integration.""" """Tests for diffusion trainer integration."""
# pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock
import pytest import pytest
import torch import torch
from unittest.mock import Mock
from axolotl.integrations.diffusion.trainer import DiffusionTrainer from axolotl.integrations.diffusion.trainer import DiffusionTrainer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -21,113 +24,122 @@ def mock_tokenizer():
@pytest.fixture @pytest.fixture
def diffusion_config(): def diffusion_config():
"""Create a diffusion config.""" """Create a diffusion config."""
return DictDefault({ return DictDefault(
"mask_token_id": 32000, {
"eps": 1e-3, "mask_token_id": 32000,
"importance_weighting": False, "eps": 1e-3,
"sample_packing": False, "importance_weighting": False,
}) "sample_packing": False,
}
)
@pytest.fixture @pytest.fixture
def diffusion_trainer(mock_tokenizer, diffusion_config): def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance.""" """Create a diffusion trainer instance for testing methods directly."""
# Create a mock model to satisfy Trainer's requirements # Create a minimal trainer instance just for testing methods
mock_model = Mock() trainer = object.__new__(DiffusionTrainer) # Bypass __init__
mock_model.training = True trainer.config = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer = DiffusionTrainer(model=mock_model)
trainer.processing_class = mock_tokenizer trainer.processing_class = mock_tokenizer
trainer.set_config(diffusion_config) trainer.store_metrics = Mock() # Mock metrics storage
return trainer return trainer
class TestDiffusionTrainer: class TestDiffusionTrainer:
"""Test the DiffusionTrainer class.""" """Test the DiffusionTrainer class."""
def test_forward_process_basic(self, diffusion_trainer): def test_forward_process_basic(self, diffusion_trainer_instance):
"""Test basic forward process without labels.""" """Test basic forward process without labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( noisy_batch, masked_indices, p_mask = (
input_ids, eps=0.1 diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
) )
# Check shapes # Check shapes
assert noisy_batch.shape == input_ids.shape assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked # Check that special tokens are not masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert not masked_indices[special_token_positions].any() assert not masked_indices[special_token_positions].any()
# Check that mask token is applied # Check that mask token is applied
mask_token_id = diffusion_trainer._config.mask_token_id mask_token_id = diffusion_trainer_instance._config.mask_token_id
masked_positions = masked_indices masked_positions = masked_indices
if masked_positions.any(): if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all() assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer): def test_forward_process_with_labels(self, diffusion_trainer_instance):
"""Test forward process with SFT labels.""" """Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( noisy_batch, masked_indices, p_mask = (
input_ids, labels=labels, eps=0.1 diffusion_trainer_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
) )
# Check shapes # Check shapes
assert noisy_batch.shape == input_ids.shape assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100) # Check that only answer tokens can be masked (where labels != -100)
answer_mask = labels != -100
non_answer_mask = labels == -100 non_answer_mask = labels == -100
# No masking should occur on non-answer tokens # No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any() assert not masked_indices[non_answer_mask].any()
# Check that probabilities are zero for non-answer tokens
assert (p_mask[non_answer_mask] == 0).all()
def test_forward_process_with_attention_mask(self, diffusion_trainer): # p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
"""Test forward process with attention mask.""" """Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1 input_ids, attention_mask=attention_mask, eps=0.1
) )
# Check that padding tokens are not masked # Check that padding tokens are not masked
padding_positions = attention_mask == 0 padding_positions = attention_mask == 0
assert not masked_indices[padding_positions].any() assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all() assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer): def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
"""Test bidirectional attention mask without sample packing.""" """Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
mask = diffusion_trainer._create_bidirectional_attention_mask(input_ids) mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids
)
# Should be all-to-all attention # Should be all-to-all attention
expected_shape = (1, 1, 4, 4) expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape assert mask.shape == expected_shape
assert mask.all() assert mask.all()
def test_bidirectional_attention_mask_with_packing(self, diffusion_trainer): def test_bidirectional_attention_mask_with_packing(
self, diffusion_trainer_instance
):
"""Test bidirectional attention mask with sample packing.""" """Test bidirectional attention mask with sample packing."""
diffusion_trainer._config.sample_packing = True diffusion_trainer_instance._config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2) # Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_trainer._create_bidirectional_attention_mask( mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids, attention_mask input_ids, attention_mask
) )
# Check that tokens within same sample can attend to each other # Check that tokens within same sample can attend to each other
# but not across samples # but not across samples
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
@@ -136,65 +148,59 @@ class TestDiffusionTrainer:
assert not mask[0, 0, 2, 4].item() assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_compute_loss_basic(self, diffusion_trainer): def test_compute_loss_basic(self, diffusion_trainer_instance):
"""Test basic loss computation.""" """Test basic loss computation."""
# Mock model that returns logits # Mock model that returns logits
mock_model = Mock() mock_model = Mock()
mock_outputs = Mock() mock_outputs = Mock()
vocab_size = 1000 vocab_size = 1000
seq_len = 5 seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size) mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Mock the store_metrics method loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
diffusion_trainer.store_metrics = Mock()
loss, outputs = diffusion_trainer._compute_diffusion_loss(
mock_model, input_ids mock_model, input_ids
) )
# Check that loss is computed # Check that loss is computed
assert isinstance(loss, torch.Tensor) assert isinstance(loss, torch.Tensor)
assert loss.requires_grad assert loss.requires_grad
assert outputs == mock_outputs assert outputs == mock_outputs
# Check that metrics were stored
diffusion_trainer.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer): # Check that metrics were stored
diffusion_trainer_instance.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer_instance):
"""Test loss computation with SFT labels.""" """Test loss computation with SFT labels."""
# Mock model # Mock model
mock_model = Mock() mock_model = Mock()
mock_outputs = Mock() mock_outputs = Mock()
vocab_size = 1000 vocab_size = 1000
seq_len = 5 seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size) mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
# Mock the store_metrics method loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
diffusion_trainer.store_metrics = Mock()
loss, outputs = diffusion_trainer._compute_diffusion_loss(
mock_model, input_ids, labels=labels mock_model, input_ids, labels=labels
) )
# Check that loss is computed # Check that loss is computed
assert isinstance(loss, torch.Tensor) assert isinstance(loss, torch.Tensor)
assert loss.requires_grad assert loss.requires_grad
# Check that SFT metrics were added # Check that SFT metrics were added
call_args = diffusion_trainer.store_metrics.call_args[0][0] call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
assert "answer_ratio" in call_args assert "answer_ratio" in call_args
assert "avg_answer_length" in call_args assert "avg_answer_length" in call_args
def test_compute_loss_no_masked_tokens(self, diffusion_trainer): def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
"""Test loss computation when no tokens are masked.""" """Test loss computation when no tokens are masked."""
# Mock model # Mock model
mock_model = Mock() mock_model = Mock()
@@ -204,38 +210,33 @@ class TestDiffusionTrainer:
mock_outputs.logits = torch.randn(1, seq_len, vocab_size) mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
mock_model.return_value = mock_outputs mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
# Only special tokens (which won't be masked) # Only special tokens (which won't be masked)
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
# Mock the store_metrics method loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
diffusion_trainer.store_metrics = Mock()
loss, outputs = diffusion_trainer._compute_diffusion_loss(
mock_model, input_ids mock_model, input_ids
) )
# Loss should be zero when no tokens are masked # Loss should be zero when no tokens are masked
assert loss.item() == 0.0 assert loss.item() == 0.0
assert loss.requires_grad assert loss.requires_grad
def test_cache_special_token_ids(self, diffusion_trainer, mock_tokenizer): def test_cache_special_token_ids(self, diffusion_trainer_instance):
"""Test caching of special token IDs.""" """Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens # Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_trainer._special_token_ids == expected_tokens assert diffusion_trainer_instance._special_token_ids == expected_tokens
def test_cache_special_token_ids_no_tokenizer(self): def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available.""" """Test caching when no tokenizer is available."""
# Create a mock model to satisfy Trainer's requirements trainer = object.__new__(DiffusionTrainer) # Bypass __init__
mock_model = Mock()
trainer = DiffusionTrainer(model=mock_model)
trainer.processing_class = None trainer.processing_class = None
trainer._cache_special_token_ids() trainer._cache_special_token_ids()
assert trainer._special_token_ids == set() assert trainer._special_token_ids == set()
def test_main_compute_loss_interface(self, diffusion_trainer): def test_main_compute_loss_interface(self, diffusion_trainer_instance):
"""Test the main compute_loss interface.""" """Test the main compute_loss interface."""
# Mock model # Mock model
mock_model = Mock() mock_model = Mock()
@@ -243,31 +244,28 @@ class TestDiffusionTrainer:
mock_outputs.logits = torch.randn(1, 5, 1000) mock_outputs.logits = torch.randn(1, 5, 1000)
mock_model.return_value = mock_outputs mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
# Mock the store_metrics method
diffusion_trainer.store_metrics = Mock()
inputs = { inputs = {
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
} }
# Test without return_outputs # Test without return_outputs
loss = diffusion_trainer.compute_loss(mock_model, inputs) loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
assert isinstance(loss, torch.Tensor) assert isinstance(loss, torch.Tensor)
# Test with return_outputs # Test with return_outputs
loss, outputs = diffusion_trainer.compute_loss( loss, outputs = diffusion_trainer_instance.compute_loss(
mock_model, inputs, return_outputs=True mock_model, inputs, return_outputs=True
) )
assert isinstance(loss, torch.Tensor) assert isinstance(loss, torch.Tensor)
assert outputs == mock_outputs assert outputs == mock_outputs
def test_missing_input_ids_raises_error(self, diffusion_trainer): def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
"""Test that missing input_ids raises ValueError.""" """Test that missing input_ids raises ValueError."""
mock_model = Mock() mock_model = Mock()
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
with pytest.raises(ValueError, match="input_ids is required"): with pytest.raises(ValueError, match="input_ids is required"):
diffusion_trainer.compute_loss(mock_model, inputs) diffusion_trainer_instance.compute_loss(mock_model, inputs)