From 077b5a43581c775631927c4c4922000ec592e768 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 16 Aug 2025 02:44:44 +0000 Subject: [PATCH] cleanup; tests draft --- .../llama-3/diffusion-3.2-1b-pretrain.yaml | 1 + examples/llama-3/diffusion-3.2-1b-sft.yaml | 1 + src/axolotl/integrations/diffusion/args.py | 5 +- src/axolotl/integrations/diffusion/trainer.py | 76 +++-- tests/e2e/test_diffusion.py | 138 +++++++++ tests/integrations/test_diffusion.py | 273 ++++++++++++++++++ 6 files changed, 469 insertions(+), 25 deletions(-) create mode 100644 tests/e2e/test_diffusion.py create mode 100644 tests/integrations/test_diffusion.py diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml index ca0271ba7..1e17a0ea1 100644 --- a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml +++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml @@ -18,6 +18,7 @@ max_mask_ratio: 0.85 num_diffusion_steps: 128 eps: 5e-4 importance_weighting: true +mask_token_id: 128002 output_dir: ./outputs/model-out diff --git a/examples/llama-3/diffusion-3.2-1b-sft.yaml b/examples/llama-3/diffusion-3.2-1b-sft.yaml index 019fefbb3..af00ac9fd 100644 --- a/examples/llama-3/diffusion-3.2-1b-sft.yaml +++ b/examples/llama-3/diffusion-3.2-1b-sft.yaml @@ -15,6 +15,7 @@ max_mask_ratio: 0.9 num_diffusion_steps: 128 eps: 1e-3 importance_weighting: true +mask_token_id: 128002 output_dir: ./outputs/model-out diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py index 0e27e7362..5dbb08a3d 100644 --- a/src/axolotl/integrations/diffusion/args.py +++ b/src/axolotl/integrations/diffusion/args.py @@ -43,5 +43,8 @@ class DiffusionArgs(BaseModel): ) mask_token_id: int = Field( default=128002, - description="Token ID to use for masking. Default is 128002 (<|reserved_special_token_0|> for Llama 3.2)", + description=( + "Token ID to use for masking. Default is 128002 " + "(<|reserved_special_token_0|> for Llama 3.2)" + ) ) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index 9bf000b6d..7b3181f32 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -36,15 +36,17 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors """Override compute_loss to use diffusion loss.""" input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") + labels = inputs.get("labels") if input_ids is None: raise ValueError("input_ids is required for diffusion training") - loss, outputs = self._compute_diffusion_loss(model, input_ids, attention_mask) + loss, outputs = self._compute_diffusion_loss( + model, input_ids, attention_mask, labels + ) if return_outputs: return loss, outputs - return loss def _cache_special_token_ids(self): @@ -70,6 +72,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, eps: float = 1e-3, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -79,6 +82,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors Args: input_ids: Input token ids [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. eps: Small epsilon value for minimum masking probability. 
Returns: @@ -101,22 +105,25 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors valid_mask = attention_mask.bool() p_mask = p_mask * valid_mask.float() - # Create mask to exclude special tokens (BOS, EOS, PAD) using cached IDs + # Create mask to exclude special tokens special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) if self._special_token_ids: for token_id in self._special_token_ids: special_token_mask |= input_ids == token_id - # Create random mask based on probability, excluding special tokens + # Create random mask based on p_mask masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask masked_indices = masked_indices & ~special_token_mask if attention_mask is not None: masked_indices = masked_indices & attention_mask.bool() + + # For SFT data, only mask answer tokens + if labels is not None: + answer_mask = labels != -100 + masked_indices = masked_indices & answer_mask - # Get mask token ID from config + # Create masked input mask_token_id = self._config.mask_token_id - - # Create masked input using configured mask token noisy_batch = torch.where(masked_indices, mask_token_id, input_ids) return noisy_batch, masked_indices, p_mask @@ -126,9 +133,9 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None ) -> torch.Tensor: """ - Create bidirectional attention mask to override default causal masking. - Handles sample-packed sequences where different samples are identified - by different attention mask values. + Create bidirectional attention mask to override default causal masking. Handles + sample-packed sequences where different samples are identified by different + attention mask values. Args: input_ids: Input token ids [batch_size, seq_len]. @@ -141,7 +148,6 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors device = input_ids.device if attention_mask is None or not self._config.sample_packing: - # Simple case: no attention mask, allow all-to-all attention return torch.ones( batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device ) @@ -163,6 +169,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | Any]: """ Compute diffusion loss. @@ -171,6 +178,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors model: The model to compute loss for. input_ids: Ground truth token ids [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. Returns: loss: Cross-entropy loss. 
@@ -178,7 +186,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors """ # Apply forward process noisy_batch, masked_indices, p_mask = self._forward_process( - input_ids, attention_mask, self._config.eps + input_ids, attention_mask, labels, self._config.eps ) # Create bidirectional attention mask @@ -197,29 +205,43 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors valid_indices = torch.where(masked_indices) batch_indices, seq_indices = valid_indices - # Extract the relevant data - masked_logits = logits[ - batch_indices, seq_indices - ] # [num_masked_tokens, vocab_size] - masked_targets = input_ids[ - batch_indices, seq_indices - ] # [num_masked_tokens] - masked_p_mask = p_mask[batch_indices, seq_indices] # [num_masked_tokens] + masked_logits = logits[batch_indices, seq_indices] + masked_targets = input_ids[batch_indices, seq_indices] + masked_p_mask = p_mask[batch_indices, seq_indices] - # Compute cross-entropy loss without reduction (cast to fp32 for stability) + # Compute cross-entropy loss without reduction token_loss = F.cross_entropy( masked_logits.float(), masked_targets, reduction="none" ) - # Apply importance weighting if enabled if self._config.importance_weighting: masked_p_mask = masked_p_mask.float() weighted_loss = token_loss / masked_p_mask else: weighted_loss = token_loss - # Final loss: sum weighted losses, normalize by total tokens - loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) + # Final loss: sum weighted losses, normalize + if labels is not None: + # For SFT data: normalize by answer length per sample as per LLaDA guidelines + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] + + # Get batch indices for masked tokens + masked_batch_indices = batch_indices + + # Sum losses per sample and divide by answer length + loss_per_sample = torch.zeros(input_ids.shape[0], device=input_ids.device) + for i in range(input_ids.shape[0]): + sample_mask = masked_batch_indices == i + if sample_mask.sum() > 0: + sample_loss = weighted_loss[sample_mask].sum() + loss_per_sample[i] = sample_loss / answer_lengths[i] + + loss = loss_per_sample.mean() + else: + # Original normalization for non-SFT data + loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1]) + ce_loss = token_loss.mean() # Compute accuracy on masked tokens @@ -240,6 +262,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors "avg_p_mask": p_mask[masked_indices].mean().item(), "ce_loss": ce_loss.item(), } + + # Add SFT-specific metrics + if labels is not None: + answer_mask = labels != -100 + metrics["answer_ratio"] = answer_mask.float().mean().item() + metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item() if self._config.importance_weighting: metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py new file mode 100644 index 000000000..4f8735d77 --- /dev/null +++ b/tests/e2e/test_diffusion.py @@ -0,0 +1,138 @@ +""" +E2E smoke test for diffusion training plugin +""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +class TestDiffusion: + """ + Test case for diffusion training plugin + """ + + def test_diffusion_smoke_test(self, temp_dir): + """ + 
Smoke test for diffusion training to ensure the plugin loads and trains without error.
+        """
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "trust_remote_code": True,
+                "sequence_len": 256,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 3,  # Very short for smoke test
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.0001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "bf16": True,
+                "save_safetensors": True,
+                "save_first_step": False,
+                "logging_steps": 1,
+                "eval_steps": 3,
+                "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
+                # Diffusion-specific config
+                "diffusion_mask_token_id": 32000,
+                "diffusion_eps": 1e-3,
+                "diffusion_importance_weighting": False,
+            }
+        )
+
+        # Validate and normalize config (validate first; normalize_config mutates in place)
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+
+        # Load datasets to ensure they work with diffusion training
+        datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({}))
+        assert datasets_meta.train_dataset is not None
+        assert len(datasets_meta.train_dataset) > 0
+
+        # Run training
+        train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta)
+
+        # Check that model was saved
+        check_model_output_exists(cfg)
+
+    def test_diffusion_sft_labels(self, temp_dir):
+        """
+        Test that diffusion training properly handles SFT data with labels.
+        """
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "trust_remote_code": True,
+                "sequence_len": 256,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,  # Very short for smoke test
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.0001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "bf16": True,
+                "save_safetensors": True,
+                "save_first_step": False,
+                "logging_steps": 1,
+                "eval_steps": 2,
+                "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
+                # Diffusion-specific config
+                "diffusion_mask_token_id": 32000,
+                "diffusion_eps": 1e-3,
+                "diffusion_importance_weighting": True,  # Test importance weighting
+                # Ensure we have proper SFT labels
+                "train_on_inputs": False,  # This ensures prompt tokens get -100 labels
+            }
+        )
+
+        # Validate and normalize config (validate first; normalize_config mutates in place)
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+
+        # Load datasets
+        datasets_meta = load_datasets(cfg=cfg, cli_args=DictDefault({}))
+
+        # Verify that the dataset has labels
+        sample = datasets_meta.train_dataset[0]
+        assert "labels" in sample, "SFT dataset should have labels"
+
+        # Check that some labels are -100 (prompt tokens)
+        labels = sample["labels"]
+        if hasattr(labels, "tolist"):
+            labels = labels.tolist()
+        assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
+
+        # Run training
+        train(cfg=cfg, cli_args=DictDefault({}), dataset_meta=datasets_meta)
+
+        # Check that model was saved
+        check_model_output_exists(cfg)
\ No newline at end of file
diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py
new file mode 100644
index 000000000..6f316059e
--- /dev/null
+++ b/tests/integrations/test_diffusion.py
@@ -0,0
+1,273 @@
+"""Tests for diffusion trainer integration."""
+
+import pytest
+import torch
+from unittest.mock import Mock
+
+from axolotl.integrations.diffusion.trainer import DiffusionTrainer
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture
+def mock_tokenizer():
+    """Create a mock tokenizer."""
+    tokenizer = Mock()
+    tokenizer.bos_token_id = 1
+    tokenizer.eos_token_id = 2
+    tokenizer.pad_token_id = 0
+    return tokenizer
+
+
+@pytest.fixture
+def diffusion_config():
+    """Create a diffusion config."""
+    return DictDefault({
+        "mask_token_id": 32000,
+        "eps": 1e-3,
+        "importance_weighting": False,
+        "sample_packing": False,
+    })
+
+
+@pytest.fixture
+def diffusion_trainer(mock_tokenizer, diffusion_config):
+    """Create a diffusion trainer instance."""
+    # Create a mock model to satisfy Trainer's requirements
+    mock_model = Mock()
+    mock_model.training = True
+
+    trainer = DiffusionTrainer(model=mock_model)
+    trainer.processing_class = mock_tokenizer
+    trainer.set_config(diffusion_config)
+    return trainer
+
+
+class TestDiffusionTrainer:
+    """Test the DiffusionTrainer class."""
+
+    def test_forward_process_basic(self, diffusion_trainer):
+        """Test basic forward process without labels."""
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process(
+            input_ids, eps=0.1
+        )
+
+        # Check shapes
+        assert noisy_batch.shape == input_ids.shape
+        assert masked_indices.shape == input_ids.shape
+        assert p_mask.shape == input_ids.shape
+
+        # Check that special tokens are not masked
+        special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
+        assert not masked_indices[special_token_positions].any()
+
+        # Check that mask token is applied
+        mask_token_id = diffusion_trainer._config.mask_token_id
+        masked_positions = masked_indices
+        if masked_positions.any():
+            assert (noisy_batch[masked_positions] == mask_token_id).all()
+
+    def test_forward_process_with_labels(self, diffusion_trainer):
+        """Test forward process with SFT labels."""
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+
+        noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process(
+            input_ids, labels=labels, eps=0.1
+        )
+
+        # Check shapes
+        assert noisy_batch.shape == input_ids.shape
+        assert masked_indices.shape == input_ids.shape
+        assert p_mask.shape == input_ids.shape
+
+        # Check that only answer tokens can be masked (where labels != -100)
+        answer_mask = labels != -100
+        non_answer_mask = ~answer_mask
+
+        # No masking should occur on non-answer tokens
+        assert not masked_indices[non_answer_mask].any()
+
+        # Prompt tokens must remain unchanged in the noisy batch
+        assert (noisy_batch[non_answer_mask] == input_ids[non_answer_mask]).all()
+
+    def test_forward_process_with_attention_mask(self, diffusion_trainer):
+        """Test forward process with attention mask."""
+        input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
+
+        noisy_batch, masked_indices, p_mask = diffusion_trainer._forward_process(
+            input_ids, attention_mask=attention_mask, eps=0.1
+        )
+
+        # Check that padding tokens are not masked
+        padding_positions = attention_mask == 0
+        assert not masked_indices[padding_positions].any()
+        assert (p_mask[padding_positions] == 0).all()
+
+    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer):
+        """Test bidirectional attention mask without sample packing."""
+        input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
+
+        mask = diffusion_trainer._create_bidirectional_attention_mask(input_ids)
+
+        # Should be all-to-all attention
+        expected_shape = (1, 1, 4, 4)
+        assert mask.shape == expected_shape
+        assert mask.all()
+
+    def test_bidirectional_attention_mask_with_packing(self, diffusion_trainer):
+        """Test bidirectional attention mask with sample packing."""
+        diffusion_trainer._config.sample_packing = True
+        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
+        # Sample IDs: first sample (1), second sample (2)
+        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
+
+        mask = diffusion_trainer._create_bidirectional_attention_mask(
+            input_ids, attention_mask
+        )
+
+        # Check that tokens within same sample can attend to each other
+        # but not across samples
+        assert mask[0, 0, 0, 1].item()  # First sample tokens can attend to each other
+        assert mask[0, 0, 1, 2].item()
+        assert not mask[0, 0, 0, 3].item()  # Can't attend across samples
+        assert not mask[0, 0, 2, 4].item()
+        assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other
+
+    def test_compute_loss_basic(self, diffusion_trainer):
+        """Test basic loss computation."""
+        # Mock model that returns logits (requires_grad so the loss carries a grad_fn)
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 5
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        # Mock the store_metrics method
+        diffusion_trainer.store_metrics = Mock()
+
+        loss, outputs = diffusion_trainer._compute_diffusion_loss(
+            mock_model, input_ids
+        )
+
+        # Check that loss is computed
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
+        assert outputs == mock_outputs
+
+        # Check that metrics were stored
+        diffusion_trainer.store_metrics.assert_called_once()
+
+    def test_compute_loss_with_labels(self, diffusion_trainer):
+        """Test loss computation with SFT labels."""
+        # Mock model
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 5
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+
+        # Mock the store_metrics method
+        diffusion_trainer.store_metrics = Mock()
+
+        loss, outputs = diffusion_trainer._compute_diffusion_loss(
+            mock_model, input_ids, labels=labels
+        )
+
+        # Check that loss is computed
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
+
+        # Check that SFT metrics were added
+        call_args = diffusion_trainer.store_metrics.call_args[0][0]
+        assert "answer_ratio" in call_args
+        assert "avg_answer_length" in call_args
+
+    def test_compute_loss_no_masked_tokens(self, diffusion_trainer):
+        """Test loss computation when no tokens are masked."""
+        # Mock model
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 3
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        # Only special tokens (which won't be masked)
+        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
+
+        # Mock the store_metrics method
+        diffusion_trainer.store_metrics = Mock()
+
+        loss, outputs = diffusion_trainer._compute_diffusion_loss(
+            mock_model, input_ids
+        )
+
+        # Loss should be zero when
no tokens are masked + assert loss.item() == 0.0 + assert loss.requires_grad + + def test_cache_special_token_ids(self, diffusion_trainer, mock_tokenizer): + """Test caching of special token IDs.""" + # Should cache BOS, EOS, PAD tokens + expected_tokens = {0, 1, 2} # pad, bos, eos + assert diffusion_trainer._special_token_ids == expected_tokens + + def test_cache_special_token_ids_no_tokenizer(self): + """Test caching when no tokenizer is available.""" + # Create a mock model to satisfy Trainer's requirements + mock_model = Mock() + trainer = DiffusionTrainer(model=mock_model) + trainer.processing_class = None + trainer._cache_special_token_ids() + + assert trainer._special_token_ids == set() + + def test_main_compute_loss_interface(self, diffusion_trainer): + """Test the main compute_loss interface.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + mock_outputs.logits = torch.randn(1, 5, 1000) + mock_model.return_value = mock_outputs + mock_model.training = True + + # Mock the store_metrics method + diffusion_trainer.store_metrics = Mock() + + inputs = { + "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), + "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), + } + + # Test without return_outputs + loss = diffusion_trainer.compute_loss(mock_model, inputs) + assert isinstance(loss, torch.Tensor) + + # Test with return_outputs + loss, outputs = diffusion_trainer.compute_loss( + mock_model, inputs, return_outputs=True + ) + assert isinstance(loss, torch.Tensor) + assert outputs == mock_outputs + + def test_missing_input_ids_raises_error(self, diffusion_trainer): + """Test that missing input_ids raises ValueError.""" + mock_model = Mock() + inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} + + with pytest.raises(ValueError, match="input_ids is required"): + diffusion_trainer.compute_loss(mock_model, inputs) \ No newline at end of file
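
Reviewer note (suggestion, not part of the patch): the per-sample answer-length normalization added to `_compute_diffusion_loss` loops over the batch in Python. A vectorized sketch follows; the helper name `sft_loss_vectorized` and its standalone signature are hypothetical, and it assumes the same `weighted_loss`, `batch_indices`, and `answer_lengths` tensors built in the patch. Its output should be verified against the loop before adopting it.

import torch


def sft_loss_vectorized(
    weighted_loss: torch.Tensor,   # [num_masked_tokens] weighted CE loss per masked token
    batch_indices: torch.Tensor,   # [num_masked_tokens] sample index of each masked token
    answer_lengths: torch.Tensor,  # [batch_size] answer-token count per sample
) -> torch.Tensor:
    """Sum masked-token losses per sample, normalize by answer length, then average."""
    loss_per_sample = torch.zeros(
        answer_lengths.shape[0], device=weighted_loss.device, dtype=weighted_loss.dtype
    )
    # Scatter-add each token's loss into its sample's slot (replaces the Python loop).
    loss_per_sample.index_add_(0, batch_indices, weighted_loss)
    # Clamp avoids division by zero for samples with no answer tokens (their loss is 0).
    return (loss_per_sample / answer_lengths.clamp(min=1)).mean()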