diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 94ee8d4b1..4bf7581d6 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -147,7 +147,7 @@ class BasePlugin: """ # pylint: disable=unused-argument - def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: """Returns a custom class for the trainer. Args: diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md index 7a1e909a6..b4176dd60 100644 --- a/src/axolotl/integrations/diffusion/README.md +++ b/src/axolotl/integrations/diffusion/README.md @@ -27,15 +27,24 @@ Add the following to your Axolotl configuration YAML: ```yaml # Enable diffusion LM training plugin plugins: - - diffusion.DiffusionPlugin + - axolotl.integrations.diffusion.DiffusionPlugin # Diffusion-specific configuration -noise_schedule: "linear" # or "cosine" +noise_schedule: linear # or "cosine" min_mask_ratio: 0.1 max_mask_ratio: 0.9 num_diffusion_steps: 128 eps: 1e-3 importance_weighting: true +mask_token_id: 128002 + +# Sample generation (optional) +generate_samples: true +generation_interval: 100 +num_generation_samples: 3 +generation_steps: 128 +generation_temperature: 0.0 +generation_max_length: 100 # Model configuration base_model: meta-llama/Llama-3.2-1B @@ -88,24 +97,37 @@ loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens - Consider using gradient checkpointing, torch.compile, ### Training Stability -- Start with `noise_schedule: "linear"` for more predictable behavior -- Enable `importance_weighting` for better gradient scaling +- Start with `noise_schedule: linear` for more predictable behavior +- Enable `importance_weighting: true` for better gradient scaling ### Convergence - Monitor the `diffusion_loss` and `diffusion_accuracy` metrics - Expect different loss curves compared to standard language modeling +## Sample Generation + +When `generate_samples: true`, the plugin generates samples during training: + +``` +📝 Sample 1: + Original (45 tokens): The quick brown fox jumps over the lazy dog... + Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]... + Generated: The quick brown fox jumps over the lazy dog... +``` + +Samples are logged to console and wandb (if enabled). + ## Metrics and Monitoring The plugin adds several metrics to track diffusion training: -- `train/diffusion_loss`: Weighted diffusion loss -- `train/diffusion_accuracy`: Accuracy on masked tokens -- `train/diffusion_mask_ratio`: Average fraction of tokens masked -- `train/diffusion_num_masked_tokens`: Number of tokens masked -- `train/diffusion_avg_p_mask`: Average masking probability -- `train/diffusion_ce_loss`: Unweighted cross-entropy loss -- `train/diffusion_importance_weight_avg`: Average importance weight +- `train/loss`: Weighted diffusion loss +- `train/accuracy`: Accuracy on masked tokens +- `train/mask_ratio`: Average fraction of tokens masked +- `train/num_masked_tokens`: Number of tokens masked +- `train/avg_p_mask`: Average masking probability +- `train/ce_loss`: Unweighted cross-entropy loss +- `train/importance_weight_avg`: Average importance weight ## Limitations diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py index 5dbb08a3d..1f6263e47 100644 --- a/src/axolotl/integrations/diffusion/args.py +++ b/src/axolotl/integrations/diffusion/args.py @@ -46,5 +46,27 @@ class DiffusionArgs(BaseModel): description=( "Token ID to use for masking. Default is 128002 " "(<|reserved_special_token_0|> for Llama 3.2)" - ) + ), + ) + + # Sample generation config + generate_samples: bool = Field( + default=True, description="Enable sample generation during training" + ) + generation_interval: int = Field( + default=100, ge=1, description="Generate samples every N steps" + ) + num_generation_samples: int = Field( + default=3, ge=1, description="Number of samples to generate each time" + ) + generation_steps: int = Field( + default=128, ge=1, description="Number of diffusion steps for generation" + ) + generation_temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for generation sampling (0.0 = deterministic)", + ) + generation_max_length: int = Field( + default=100, ge=1, description="Maximum sequence length for generation" ) diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py new file mode 100644 index 000000000..fafd99919 --- /dev/null +++ b/src/axolotl/integrations/diffusion/callbacks.py @@ -0,0 +1,115 @@ +"""Callbacks for diffusion training.""" + +import wandb +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from axolotl.utils.logging import get_logger + +from .generation import generate_samples + +LOG = get_logger(__name__) + + +class DiffusionGenerationCallback(TrainerCallback): + """Callback for generating samples during diffusion training.""" + + def __init__(self, trainer): + self.trainer = trainer + + # pylint: disable=unused-argument + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Generate samples at specified intervals.""" + # Only generate samples at the specified interval and after step 0 + if ( + state.global_step > 0 + and state.global_step % self.trainer.config.generation_interval == 0 + and hasattr(self.trainer, "eval_dataset") + and self.trainer.eval_dataset is not None + ): + + LOG.info( + f"Generating {self.trainer.config.num_generation_samples} samples at step {state.global_step}..." + ) + + # Create a simple dataloader from eval dataset for sampling + eval_dataloader = self.trainer.get_eval_dataloader() + + # Generate samples + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.tokenizer, + val_dataloader=eval_dataloader, + num_generation_samples=self.trainer.config.num_generation_samples, + max_length=self.trainer.config.generation_max_length, + num_diffusion_steps=self.trainer.config.generation_steps, + temperature=self.trainer.config.generation_temperature, + mask_token_id=self.trainer.config.mask_token_id, + ) + + # Log samples + self._log_samples(samples, state.global_step) + + def _log_samples(self, samples: list, step: int): + """Log generated samples.""" + if not samples: + return + + LOG.info("=" * 60) + LOG.info("GENERATED SAMPLES") + LOG.info("=" * 60) + + for i, sample_data in enumerate(samples, 1): + original = sample_data["original"] + masked = sample_data["masked"] + generated = sample_data["generated"] + mask_ratio = sample_data["mask_ratio"] + masked_tokens = sample_data["masked_tokens"] + total_tokens = sample_data["total_tokens"] + + LOG.info(f"\nSample {i}:") + LOG.info(f"\tOriginal ({total_tokens} tokens): {original}") + LOG.info( + f"\tMasked ({masked_tokens}/{total_tokens} tokens, " + f"{mask_ratio:.1%}): {masked}" + ) + LOG.info(f"\tGenerated: {generated}") + + LOG.info("=" * 60) + + if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero: + if wandb.run is not None: + wandb.log( + { + "generated_samples": wandb.Table( + columns=[ + "step", + "original", + "masked", + "generated", + "mask_ratio", + "masked_tokens", + "total_tokens", + ], + data=[ + [ + step, + sample["original"], + sample["masked"], + sample["generated"], + f"{sample['mask_ratio']:.1%}", + sample["masked_tokens"], + sample["total_tokens"], + ] + for sample in samples + ], + ) + }, + step=step, + ) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py new file mode 100644 index 000000000..d019ec023 --- /dev/null +++ b/src/axolotl/integrations/diffusion/generation.py @@ -0,0 +1,267 @@ +"""Sample generation utilities for diffusion training.""" + +import logging +from typing import Any, List, Optional + +import torch + +logger = logging.getLogger(__name__) + + +def generate_samples( + model: torch.nn.Module, + tokenizer: Any, + val_dataloader: Optional[Any] = None, + num_generation_samples: int = 3, + max_length: int = 100, + num_diffusion_steps: int = 128, + temperature: float = 0.0, + mask_token_id: int = 32000, +) -> List[dict]: + """ + Generate text samples using the diffusion model by randomly masking sequences + from the validation dataset and running the reverse diffusion process. + + Args: + model: The wrapped or unwrapped model + tokenizer: Tokenizer for encoding/decoding + val_dataloader: Validation dataloader (for sampling sequences) + num_generation_samples: Number of samples to generate + max_length: Maximum length of sequences to use + num_diffusion_steps: Number of diffusion steps for generation + temperature: Temperature for sampling (0.0 = deterministic) + mask_token_id: Token ID used for masking + + Returns: + List of dictionaries with original text, masked text, and generated text + """ + if val_dataloader is None: + logger.warning("No validation dataloader provided, cannot generate samples") + return [] + + # Get the actual model (unwrap if needed) + unwrapped_model = model.module if hasattr(model, "module") else model + unwrapped_model.eval() + generations = [] + + # Sample sequences from validation dataset + sampled_sequences = _sample_sequences_from_dataloader( + val_dataloader, num_generation_samples, max_length, unwrapped_model.device + ) + logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset") + + # Generate samples using reverse diffusion process + with torch.no_grad(): + for original_sequence in sampled_sequences: + generation_result = _generate( + unwrapped_model, + tokenizer, + original_sequence, + num_diffusion_steps, + temperature, + mask_token_id, + ) + generations.append(generation_result) + + unwrapped_model.train() + return generations + + +def _sample_sequences_from_dataloader( + val_dataloader: Any, num_samples: int, max_length: int, device: torch.device +) -> List[torch.Tensor]: + """Sample sequences from validation dataloader.""" + sampled_sequences = [] + sample_count = 0 + + # Add randomness by skipping a random number of batches + skip_batches = torch.randint(0, 6, (1,)).item() + batch_count = 0 + + for batch in val_dataloader: + # Skip some batches for variety + if batch_count < skip_batches: + batch_count += 1 + continue + + if sample_count >= num_samples: + break + + batch_count += 1 + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask") + + # Randomly sample from sequences in this batch + batch_indices = torch.randperm(input_ids.size(0)).tolist() + + for i in batch_indices: + if sample_count >= num_samples: + break + + # Get actual sequence length (non-padded) + if attention_mask is not None: + seq_len = attention_mask[i].sum().item() + else: + seq_len = input_ids.size(1) + + # Limit sequence length to max_length + actual_length = min(seq_len, max_length) + if actual_length < 10: # Skip very short sequences + continue + + # Extract the sequence + sequence = input_ids[i][:actual_length].unsqueeze(0).to(device) + sampled_sequences.append(sequence) + sample_count += 1 + + return sampled_sequences + + +def _generate( + model: torch.nn.Module, + tokenizer: Any, + original_sequence: torch.Tensor, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, +) -> dict: + """Generate a single sample using reverse diffusion.""" + # Get original text for comparison + original_text = tokenizer.decode( + original_sequence[0].cpu(), skip_special_tokens=True + ) + + # Apply custom masking with random ratio (10% to 70%) + total_tokens = original_sequence.size(1) + min_ratio, max_ratio = 0.1, 0.7 + target_mask_ratio = torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio + target_masked_tokens = int(total_tokens * target_mask_ratio) + + # Create random mask indices + mask_positions = torch.randperm(total_tokens)[:target_masked_tokens] + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, mask_positions] = True + + # Create masked sequence + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + + # Calculate actual mask ratio + masked_tokens = masked_indices.sum().item() + mask_ratio = masked_tokens / total_tokens + + # Get masked text for comparison + masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False) + # Clean up mask token representation + masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id) + + # Run reverse diffusion process + sequence = masked_sequence.clone() + for step in range(num_diffusion_steps): + sequence = _diffusion_step( + model, sequence, step, num_diffusion_steps, temperature, mask_token_id + ) + + # Get final generated text + generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True) + + return { + "original": original_text, + "masked": masked_text, + "generated": generated_text, + "mask_ratio": mask_ratio, + "masked_tokens": masked_tokens, + "total_tokens": total_tokens, + "formatted": ( + f"Original: '{original_text}' → Masked: '{masked_text}' " + f"({mask_ratio:.1%}) → Generated: '{generated_text}'" + ), + } + + +def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str: + """Clean up masked text for display.""" + # Get the mask token representation from the tokenizer + mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False) + cleaned = masked_text.replace(mask_token_repr, "[MASK]") + + # Clean up special tokens and whitespace + cleaned = cleaned.replace("", "").replace("", "").strip() + cleaned = " ".join(cleaned.split()) + + return cleaned + + +def _diffusion_step( + model: torch.nn.Module, + sequence: torch.Tensor, + step: int, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, +) -> torch.Tensor: + """Perform a single diffusion step with remasking.""" + # Only process if there are masked tokens remaining + current_mask = sequence == mask_token_id + if not current_mask.any(): + return sequence + + # Create bidirectional attention mask for diffusion + batch_size, seq_len = sequence.shape + attention_mask = torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device + ) + + # Forward pass + outputs = model(input_ids=sequence, attention_mask=attention_mask) + logits = outputs.logits + + # Only sample at currently masked positions + if current_mask.any(): + masked_logits = logits[current_mask] + + # Apply temperature scaling + if temperature > 0: + scaled_logits = masked_logits / temperature + else: + scaled_logits = masked_logits + + # Suppress mask token in outputs + scaled_logits[:, mask_token_id] = -float("inf") + + # Sample predictions + if temperature > 0: + # Add Gumbel noise for sampling + gumbel_noise = -torch.log( + -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32)) + ) + gumbel_logits = scaled_logits + gumbel_noise + predicted_tokens = torch.argmax(gumbel_logits, dim=-1) + else: + # Deterministic sampling when temperature is 0 + predicted_tokens = torch.argmax(scaled_logits, dim=-1) + + # Calculate probabilities for confidence scoring + probs = torch.softmax(scaled_logits, dim=-1) + predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens] + + # Determine how many tokens to unmask this step + remaining_masked = current_mask.sum().item() + if step == num_diffusion_steps - 1: + num_to_unmask = remaining_masked + else: + unmask_ratio = 1.0 / (num_diffusion_steps - step) + num_to_unmask = max(1, int(remaining_masked * unmask_ratio)) + + # Select highest confidence predictions to unmask + if num_to_unmask >= remaining_masked: + sequence[current_mask] = predicted_tokens + else: + _, top_indices = predicted_token_probs.topk(num_to_unmask) + mask_positions = torch.where(current_mask)[1] + positions_to_unmask = mask_positions[top_indices] + sequence[0, positions_to_unmask] = predicted_tokens[top_indices] + + return sequence diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py index 76eea006d..c31f48b03 100644 --- a/src/axolotl/integrations/diffusion/plugin.py +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -1,6 +1,7 @@ """Diffusion LM training plugin for Axolotl.""" -from transformers import PreTrainedModel, Trainer +from peft import PeftModel +from transformers import PreTrainedModel from axolotl.integrations.base import BasePlugin from axolotl.utils.dict import DictDefault @@ -27,14 +28,14 @@ class DiffusionPlugin(BasePlugin): """Returns the pydantic model for LLaDA plugin arguments.""" return "axolotl.integrations.diffusion.DiffusionArgs" - def post_model_load(self, cfg: DictDefault, model: PreTrainedModel): + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Perform actions after model is loaded.""" self.cfg = cfg - def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: """Return custom trainer class for diffusion training.""" return DiffusionTrainer - def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): """Configure trainer after creation.""" trainer.set_config(cfg) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index 7b3181f32..be1bb9838 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -10,6 +10,8 @@ from axolotl.core.trainers.base import AxolotlTrainer from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger +from .callbacks import DiffusionGenerationCallback + LOG = get_logger(__name__) @@ -18,14 +20,18 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._config = None + self.config = None self._special_token_ids = None def set_config(self, config: DictDefault): """Set config for diffusion training.""" - self._config = config + self.config = config self._cache_special_token_ids() + if config.generate_samples: + generation_callback = DiffusionGenerationCallback(self) + self.add_callback(generation_callback) + def compute_loss( self, model: nn.Module, @@ -111,19 +117,19 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors for token_id in self._special_token_ids: special_token_mask |= input_ids == token_id - # Create random mask based on p_mask + # Create random mask based on p_mask masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask masked_indices = masked_indices & ~special_token_mask if attention_mask is not None: masked_indices = masked_indices & attention_mask.bool() - + # For SFT data, only mask answer tokens if labels is not None: answer_mask = labels != -100 masked_indices = masked_indices & answer_mask # Create masked input - mask_token_id = self._config.mask_token_id + mask_token_id = self.config.mask_token_id noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) return noisy_batch, masked_indices, p_mask @@ -147,7 +153,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors batch_size, seq_len = input_ids.shape device = input_ids.device - if attention_mask is None or not self._config.sample_packing: + if attention_mask is None or not self.config.sample_packing: return torch.ones( batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device ) @@ -186,7 +192,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors """ # Apply forward process noisy_batch, masked_indices, p_mask = self._forward_process( - input_ids, attention_mask, labels, self._config.eps + input_ids, attention_mask, labels, self.config.eps ) # Create bidirectional attention mask @@ -214,7 +220,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors masked_logits.float(), masked_targets, reduction="none" ) - if self._config.importance_weighting: + if self.config.importance_weighting: masked_p_mask = masked_p_mask.float() weighted_loss = token_loss / masked_p_mask else: @@ -222,26 +228,28 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors # Final loss: sum weighted losses, normalize if labels is not None: - # For SFT data: normalize by answer length per sample as per LLaDA guidelines + # For SFT data: normalize by answer length per sample answer_mask = labels != -100 answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] - + # Get batch indices for masked tokens masked_batch_indices = batch_indices - + # Sum losses per sample and divide by answer length - loss_per_sample = torch.zeros(input_ids.shape[0], device=input_ids.device) + loss_per_sample = torch.zeros( + input_ids.shape[0], device=input_ids.device + ) for i in range(input_ids.shape[0]): sample_mask = masked_batch_indices == i if sample_mask.sum() > 0: sample_loss = weighted_loss[sample_mask].sum() loss_per_sample[i] = sample_loss / answer_lengths[i] - + loss = loss_per_sample.mean() else: # Original normalization for non-SFT data loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) - + ce_loss = token_loss.mean() # Compute accuracy on masked tokens @@ -262,14 +270,14 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors "avg_p_mask": p_mask[masked_indices].mean().item(), "ce_loss": ce_loss.item(), } - + # Add SFT-specific metrics if labels is not None: answer_mask = labels != -100 metrics["answer_ratio"] = answer_mask.float().mean().item() metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item() - if self._config.importance_weighting: + if self.config.importance_weighting: metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() train_eval: Literal["train", "eval"] = "train" if model.training else "eval" diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 4f8735d77..df88cb730 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -1,6 +1,4 @@ -""" -E2E smoke test for diffusion training plugin -""" +"""E2E smoke test for diffusion training plugin.""" from axolotl.common.datasets import load_datasets from axolotl.train import train @@ -11,13 +9,12 @@ from tests.e2e.utils import check_model_output_exists class TestDiffusion: - """ - Test case for diffusion training plugin - """ + """Test case for diffusion training plugin.""" def test_diffusion_smoke_test(self, temp_dir): """ - Smoke test for diffusion training to ensure the plugin loads and trains without error. + Smoke test for diffusion training to ensure the plugin loads and trains without + error. """ cfg = DictDefault( { @@ -36,7 +33,7 @@ class TestDiffusion: }, ], "num_epochs": 1, - "max_steps": 3, # Very short for smoke test + "max_steps": 3, "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, @@ -48,33 +45,23 @@ class TestDiffusion: "save_first_step": False, "logging_steps": 1, "eval_steps": 3, - "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], # Diffusion-specific config - "diffusion_mask_token_id": 32000, + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion_mask_token_id": 16, "diffusion_eps": 1e-3, "diffusion_importance_weighting": False, } ) - # Normalize and validate config - cfg = normalize_config(cfg) cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) - # Load datasets to ensure they work with diffusion training - datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({})) - assert datasets_meta.train_dataset is not None - assert len(datasets_meta.train_dataset) > 0 - - # Run training - train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta) - - # Check that model was saved - check_model_output_exists(cfg) + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) def test_diffusion_sft_labels(self, temp_dir): - """ - Test that diffusion training properly handles SFT data with labels. - """ + """Test that diffusion training properly handles SFT data with labels.""" cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -92,7 +79,7 @@ class TestDiffusion: }, ], "num_epochs": 1, - "max_steps": 2, # Very short for smoke test + "max_steps": 3, "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, @@ -104,35 +91,29 @@ class TestDiffusion: "save_first_step": False, "logging_steps": 1, "eval_steps": 2, - "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], # Diffusion-specific config - "diffusion_mask_token_id": 32000, + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion_mask_token_id": 16, "diffusion_eps": 1e-3, - "diffusion_importance_weighting": True, # Test importance weighting + "diffusion_importance_weighting": True, # Ensure we have proper SFT labels - "train_on_inputs": False, # This ensures prompt tokens get -100 labels + "train_on_inputs": False, } ) - # Normalize and validate config - cfg = normalize_config(cfg) cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) - # Load datasets - datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({})) - # Verify that the dataset has labels - sample = datasets_meta.train_dataset[0] + sample = dataset_meta.train_dataset[0] assert "labels" in sample, "SFT dataset should have labels" - + # Check that some labels are -100 (prompt tokens) labels = sample["labels"] if hasattr(labels, "tolist"): labels = labels.tolist() assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens" - # Run training - train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta) - - # Check that model was saved - check_model_output_exists(cfg) \ No newline at end of file + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py index 6f316059e..583597238 100644 --- a/tests/integrations/test_diffusion.py +++ b/tests/integrations/test_diffusion.py @@ -1,8 +1,11 @@ """Tests for diffusion trainer integration.""" +# pylint: disable=redefined-outer-name,protected-access + +from unittest.mock import Mock + import pytest import torch -from unittest.mock import Mock from axolotl.integrations.diffusion.trainer import DiffusionTrainer from axolotl.utils.dict import DictDefault @@ -21,113 +24,122 @@ def mock_tokenizer(): @pytest.fixture def diffusion_config(): """Create a diffusion config.""" - return DictDefault({ - "mask_token_id": 32000, - "eps": 1e-3, - "importance_weighting": False, - "sample_packing": False, - }) + return DictDefault( + { + "mask_token_id": 32000, + "eps": 1e-3, + "importance_weighting": False, + "sample_packing": False, + } + ) @pytest.fixture -def diffusion_trainer(mock_tokenizer, diffusion_config): - """Create a diffusion trainer instance.""" - # Create a mock model to satisfy Trainer's requirements - mock_model = Mock() - mock_model.training = True - - trainer = DiffusionTrainer(model=mock_model) +def diffusion_trainer_instance(mock_tokenizer, diffusion_config): + """Create a diffusion trainer instance for testing methods directly.""" + # Create a minimal trainer instance just for testing methods + trainer = object.__new__(DiffusionTrainer) # Bypass __init__ + trainer.config = diffusion_config + trainer._special_token_ids = {0, 1, 2} # pad, bos, eos trainer.processing_class = mock_tokenizer - trainer.set_config(diffusion_config) + trainer.store_metrics = Mock() # Mock metrics storage return trainer class TestDiffusionTrainer: """Test the DiffusionTrainer class.""" - def test_forward_process_basic(self, diffusion_trainer): + def test_forward_process_basic(self, diffusion_trainer_instance): """Test basic forward process without labels.""" input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) - - noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( - input_ids, eps=0.1 + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process(input_ids, eps=0.1) ) - + # Check shapes assert noisy_batch.shape == input_ids.shape assert masked_indices.shape == input_ids.shape assert p_mask.shape == input_ids.shape - + # Check that special tokens are not masked special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) assert not masked_indices[special_token_positions].any() - + # Check that mask token is applied - mask_token_id = diffusion_trainer._config.mask_token_id + mask_token_id = diffusion_trainer_instance._config.mask_token_id masked_positions = masked_indices if masked_positions.any(): assert (noisy_batch[masked_positions] == mask_token_id).all() - def test_forward_process_with_labels(self, diffusion_trainer): + def test_forward_process_with_labels(self, diffusion_trainer_instance): """Test forward process with SFT labels.""" input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) - - noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( - input_ids, labels=labels, eps=0.1 + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process( + input_ids, labels=labels, eps=0.1 + ) ) - + # Check shapes assert noisy_batch.shape == input_ids.shape assert masked_indices.shape == input_ids.shape assert p_mask.shape == input_ids.shape - + # Check that only answer tokens can be masked (where labels != -100) - answer_mask = labels != -100 non_answer_mask = labels == -100 - + # No masking should occur on non-answer tokens assert not masked_indices[non_answer_mask].any() - - # Check that probabilities are zero for non-answer tokens - assert (p_mask[non_answer_mask] == 0).all() - def test_forward_process_with_attention_mask(self, diffusion_trainer): + # p_mask should be the same for all positions (sampled timestep), + # but masking is only applied to answer tokens + assert p_mask.shape == input_ids.shape + # Verify that masked_indices respects the answer mask + assert not masked_indices[non_answer_mask].any() + + def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): """Test forward process with attention mask.""" input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) - - noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process( + + _, masked_indices, p_mask = diffusion_trainer_instance._forward_process( input_ids, attention_mask=attention_mask, eps=0.1 ) - + # Check that padding tokens are not masked padding_positions = attention_mask == 0 assert not masked_indices[padding_positions].any() assert (p_mask[padding_positions] == 0).all() - def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer): + def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): """Test bidirectional attention mask without sample packing.""" input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) - - mask = diffusion_trainer._create_bidirectional_attention_mask(input_ids) - + + mask = diffusion_trainer_instance._create_bidirectional_attention_mask( + input_ids + ) + # Should be all-to-all attention expected_shape = (1, 1, 4, 4) assert mask.shape == expected_shape assert mask.all() - def test_bidirectional_attention_mask_with_packing(self, diffusion_trainer): + def test_bidirectional_attention_mask_with_packing( + self, diffusion_trainer_instance + ): """Test bidirectional attention mask with sample packing.""" - diffusion_trainer._config.sample_packing = True + diffusion_trainer_instance._config.sample_packing = True input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) # Sample IDs: first sample (1), second sample (2) attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) - - mask = diffusion_trainer._create_bidirectional_attention_mask( + + mask = diffusion_trainer_instance._create_bidirectional_attention_mask( input_ids, attention_mask ) - + # Check that tokens within same sample can attend to each other # but not across samples assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other @@ -136,65 +148,59 @@ class TestDiffusionTrainer: assert not mask[0, 0, 2, 4].item() assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other - def test_compute_loss_basic(self, diffusion_trainer): + def test_compute_loss_basic(self, diffusion_trainer_instance): """Test basic loss computation.""" # Mock model that returns logits mock_model = Mock() mock_outputs = Mock() vocab_size = 1000 seq_len = 5 - mock_outputs.logits = torch.randn(1, seq_len, vocab_size) + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) mock_model.return_value = mock_outputs mock_model.training = True - + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) - - # Mock the store_metrics method - diffusion_trainer.store_metrics = Mock() - - loss, outputs = diffusion_trainer._compute_diffusion_loss( + + loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( mock_model, input_ids ) - + # Check that loss is computed assert isinstance(loss, torch.Tensor) assert loss.requires_grad assert outputs == mock_outputs - - # Check that metrics were stored - diffusion_trainer.store_metrics.assert_called_once() - def test_compute_loss_with_labels(self, diffusion_trainer): + # Check that metrics were stored + diffusion_trainer_instance.store_metrics.assert_called_once() + + def test_compute_loss_with_labels(self, diffusion_trainer_instance): """Test loss computation with SFT labels.""" # Mock model mock_model = Mock() mock_outputs = Mock() vocab_size = 1000 seq_len = 5 - mock_outputs.logits = torch.randn(1, seq_len, vocab_size) + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) mock_model.return_value = mock_outputs mock_model.training = True - + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) - - # Mock the store_metrics method - diffusion_trainer.store_metrics = Mock() - - loss, outputs = diffusion_trainer._compute_diffusion_loss( + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( mock_model, input_ids, labels=labels ) - + # Check that loss is computed assert isinstance(loss, torch.Tensor) assert loss.requires_grad - + # Check that SFT metrics were added - call_args = diffusion_trainer.store_metrics.call_args[0][0] + call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] assert "answer_ratio" in call_args assert "avg_answer_length" in call_args - def test_compute_loss_no_masked_tokens(self, diffusion_trainer): + def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): """Test loss computation when no tokens are masked.""" # Mock model mock_model = Mock() @@ -204,38 +210,33 @@ class TestDiffusionTrainer: mock_outputs.logits = torch.randn(1, seq_len, vocab_size) mock_model.return_value = mock_outputs mock_model.training = True - + # Only special tokens (which won't be masked) input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) - - # Mock the store_metrics method - diffusion_trainer.store_metrics = Mock() - - loss, outputs = diffusion_trainer._compute_diffusion_loss( + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( mock_model, input_ids ) - + # Loss should be zero when no tokens are masked assert loss.item() == 0.0 assert loss.requires_grad - def test_cache_special_token_ids(self, diffusion_trainer, mock_tokenizer): + def test_cache_special_token_ids(self, diffusion_trainer_instance): """Test caching of special token IDs.""" # Should cache BOS, EOS, PAD tokens expected_tokens = {0, 1, 2} # pad, bos, eos - assert diffusion_trainer._special_token_ids == expected_tokens + assert diffusion_trainer_instance._special_token_ids == expected_tokens def test_cache_special_token_ids_no_tokenizer(self): """Test caching when no tokenizer is available.""" - # Create a mock model to satisfy Trainer's requirements - mock_model = Mock() - trainer = DiffusionTrainer(model=mock_model) + trainer = object.__new__(DiffusionTrainer) # Bypass __init__ trainer.processing_class = None trainer._cache_special_token_ids() - + assert trainer._special_token_ids == set() - def test_main_compute_loss_interface(self, diffusion_trainer): + def test_main_compute_loss_interface(self, diffusion_trainer_instance): """Test the main compute_loss interface.""" # Mock model mock_model = Mock() @@ -243,31 +244,28 @@ class TestDiffusionTrainer: mock_outputs.logits = torch.randn(1, 5, 1000) mock_model.return_value = mock_outputs mock_model.training = True - - # Mock the store_metrics method - diffusion_trainer.store_metrics = Mock() - + inputs = { "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), } - + # Test without return_outputs - loss = diffusion_trainer.compute_loss(mock_model, inputs) + loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) assert isinstance(loss, torch.Tensor) - + # Test with return_outputs - loss, outputs = diffusion_trainer.compute_loss( + loss, outputs = diffusion_trainer_instance.compute_loss( mock_model, inputs, return_outputs=True ) assert isinstance(loss, torch.Tensor) assert outputs == mock_outputs - def test_missing_input_ids_raises_error(self, diffusion_trainer): + def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): """Test that missing input_ids raises ValueError.""" mock_model = Mock() inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} - + with pytest.raises(ValueError, match="input_ids is required"): - diffusion_trainer.compute_loss(mock_model, inputs) \ No newline at end of file + diffusion_trainer_instance.compute_loss(mock_model, inputs)