diffusion custom models approach

This commit is contained in:
Dan Saunders
2025-08-19 04:09:46 +00:00
parent 63d2280999
commit 1f75287a3a
8 changed files with 779 additions and 423 deletions

View File

@@ -1,13 +1,14 @@
"""Tests for diffusion trainer integration."""
"""Tests for diffusion model integration."""
# pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
import torch
from axolotl.integrations.diffusion.trainer import DiffusionTrainer
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
from axolotl.utils.dict import DictDefault
@@ -24,37 +25,44 @@ def mock_tokenizer():
@pytest.fixture
def diffusion_config():
"""Create a diffusion config."""
return DictDefault(
{
"mask_token_id": 32000,
"eps": 1e-3,
"importance_weighting": False,
"sample_packing": False,
}
return LlamaForDiffusionConfig(
mask_token_id=32000,
eps=1e-3,
importance_weighting=False,
sample_packing=False,
# Basic llama config fields - smaller for testing
vocab_size=1000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
)
@pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.config = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer
trainer.store_metrics = Mock() # Mock metrics storage
return trainer
def diffusion_model_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion model instance for testing methods directly."""
# Create a minimal model instance for testing
model = object.__new__(LlamaForDiffusionLM)
model.config = diffusion_config
model._special_token_ids = {0, 1, 2} # pad, bos, eos
model.training = True
# Set tokenizer
model.set_tokenizer(mock_tokenizer)
return model
class TestDiffusionTrainer:
"""Test the DiffusionTrainer class."""
class TestDiffusionModel:
"""Test the DiffusionModel class."""
def test_forward_process_basic(self, diffusion_trainer_instance):
def test_forward_process_basic(self, diffusion_model_instance):
"""Test basic forward process without labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
diffusion_model_instance._forward_process(input_ids, eps=0.1)
)
# Check shapes
@@ -67,18 +75,18 @@ class TestDiffusionTrainer:
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied
mask_token_id = diffusion_trainer_instance._config.mask_token_id
mask_token_id = diffusion_model_instance.config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer_instance):
def test_forward_process_with_labels(self, diffusion_model_instance):
"""Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(
diffusion_model_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
)
@@ -100,12 +108,12 @@ class TestDiffusionTrainer:
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
def test_forward_process_with_attention_mask(self, diffusion_model_instance):
"""Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
_, masked_indices, p_mask = diffusion_model_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1
)
@@ -114,11 +122,11 @@ class TestDiffusionTrainer:
assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
"""Test bidirectional attention mask without sample packing."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids
)
@@ -128,15 +136,15 @@ class TestDiffusionTrainer:
assert mask.all()
def test_bidirectional_attention_mask_with_packing(
self, diffusion_trainer_instance
self, diffusion_model_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance._config.sample_packing = True
diffusion_model_instance.config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
mask = diffusion_model_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
@@ -148,124 +156,135 @@ class TestDiffusionTrainer:
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
def test_compute_loss_basic(self, diffusion_trainer_instance):
def test_compute_loss_basic(self, diffusion_model_instance):
"""Test basic loss computation."""
# Mock model that returns logits
mock_model = Mock()
mock_outputs = Mock()
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create a simple masked indices tensor (mask middle tokens)
masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert outputs == mock_outputs
# Check that metrics were stored
diffusion_trainer_instance.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer_instance):
def test_compute_loss_with_labels(self, diffusion_model_instance):
"""Test loss computation with SFT labels."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 5
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
# Create masked indices that only covers answer tokens
masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids, labels=labels
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
labels=labels,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
# Check that SFT metrics were added
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
assert "answer_ratio" in call_args
assert "avg_answer_length" in call_args
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
"""Test loss computation when no tokens are masked."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
# Create mock data for loss computation
vocab_size = 1000
seq_len = 3
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
mock_model.return_value = mock_outputs
mock_model.training = True
logits = torch.randn(1, seq_len, vocab_size)
# No tokens masked
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
# Only special tokens (which won't be masked)
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
loss = diffusion_model_instance._compute_diffusion_loss(
input_ids=input_ids,
logits=logits,
masked_indices=masked_indices,
p_mask=p_mask,
)
# Loss should be zero when no tokens are masked
assert loss.item() == 0.0
assert loss.requires_grad
def test_cache_special_token_ids(self, diffusion_trainer_instance):
def test_cache_special_token_ids(self, diffusion_model_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_trainer_instance._special_token_ids == expected_tokens
assert diffusion_model_instance._special_token_ids == expected_tokens
def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available."""
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.processing_class = None
trainer._cache_special_token_ids()
# Mock the parent model initialization to avoid loading pretrained weights
with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
model._cache_special_token_ids(None)
assert model._special_token_ids == set()
assert trainer._special_token_ids == set()
def test_forward_training_mode(self, diffusion_model_instance):
"""Test forward pass in training mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_output.logits = torch.randn(1, 5, 32000)
mock_forward.return_value = mock_output
# Set training mode
diffusion_model_instance.training = True
result = diffusion_model_instance.forward(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Should call parent forward and compute loss
assert mock_forward.called
assert hasattr(result, 'loss')
def test_main_compute_loss_interface(self, diffusion_trainer_instance):
"""Test the main compute_loss interface."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
mock_outputs.logits = torch.randn(1, 5, 1000)
mock_model.return_value = mock_outputs
mock_model.training = True
inputs = {
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
}
# Test without return_outputs
loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
assert isinstance(loss, torch.Tensor)
# Test with return_outputs
loss, outputs = diffusion_trainer_instance.compute_loss(
mock_model, inputs, return_outputs=True
)
assert isinstance(loss, torch.Tensor)
assert outputs == mock_outputs
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
"""Test that missing input_ids raises ValueError."""
mock_model = Mock()
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
with pytest.raises(ValueError, match="input_ids is required"):
diffusion_trainer_instance.compute_loss(mock_model, inputs)
def test_forward_inference_mode(self, diffusion_model_instance):
"""Test forward pass in inference mode."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
# Mock the parent forward method
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
mock_output = Mock()
mock_forward.return_value = mock_output
# Set inference mode
diffusion_model_instance.training = False
result = diffusion_model_instance.forward(
input_ids=input_ids,
return_dict=True
)
# Should just call parent forward without diffusion processing
assert mock_forward.called
assert result == mock_output