# Source listing metadata: 291 lines, 11 KiB, Python
"""Tests for diffusion model integration."""
|
|
|
|
# pylint: disable=redefined-outer-name,protected-access
|
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
|
|
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
@pytest.fixture
def mock_tokenizer():
    """Return a ``Mock`` tokenizer exposing only the special-token ids.

    The ids are pad=0, bos=1 and eos=2. ``Mock`` keyword arguments are set as
    attributes on the created mock, so this is equivalent to assigning them
    one by one.
    """
    return Mock(pad_token_id=0, bos_token_id=1, eos_token_id=2)
|
|
|
|
|
|
@pytest.fixture
def diffusion_config():
    """Return a small ``LlamaForDiffusionConfig`` for the tests in this module.

    The keyword arguments are split into the diffusion-specific options and the
    basic Llama architecture fields, which are kept small for testing.
    """
    diffusion_options = {
        # NOTE(review): this id lies outside the vocab_size configured below;
        # confirm that is intentional for these tests.
        "mask_token_id": 32000,
        "eps": 1e-3,
        "importance_weighting": False,
        "sample_packing": False,
    }
    # Basic llama config fields - smaller for testing
    tiny_llama_fields = {
        "vocab_size": 1000,
        "hidden_size": 256,
        "intermediate_size": 512,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
    }
    return LlamaForDiffusionConfig(**diffusion_options, **tiny_llama_fields)
|
|
|
|
|
|
@pytest.fixture
def diffusion_model_instance(mock_tokenizer, diffusion_config):
    """Return a bare ``LlamaForDiffusionLM`` for calling its methods directly.

    The instance is allocated with ``object.__new__``, so ``__init__`` never
    runs; the state the tests rely on is assigned by hand instead.
    """
    bare_model = object.__new__(LlamaForDiffusionLM)

    # Attributes set manually because the constructor was bypassed.
    bare_model.training = True
    bare_model._special_token_ids = {0, 1, 2}  # pad, bos, eos
    bare_model.config = diffusion_config

    # Attach the tokenizer last, once the attributes above are in place.
    bare_model.set_tokenizer(mock_tokenizer)
    return bare_model
|
|
|
|
|
|
class TestDiffusionModel:
    """Exercise ``LlamaForDiffusionLM`` helper methods and its ``forward`` dispatch.

    Most tests use the bare instance from the ``diffusion_model_instance``
    fixture; the two ``forward`` tests replace the parent class's ``forward``
    with a mock so that no real model computation is needed.
    """

    # Mirrors ``vocab_size`` in the ``diffusion_config`` fixture.
    VOCAB_SIZE = 1000

    @staticmethod
    def _patch_parent_forward(model):
        """Return a patcher that swaps the parent class's ``forward`` for a mock.

        The parent is taken as the second entry of the model class's
        ``__bases__``.
        NOTE(review): this positional lookup silently targets a different class
        if the base-class order of ``LlamaForDiffusionLM`` changes - confirm it
        still points at the class whose ``forward`` should be stubbed.
        """
        return patch.object(type(model).__bases__[1], "forward")

    def test_forward_process_basic(self, diffusion_model_instance):
        """Forward process without labels: shapes, special tokens, mask id."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)

        noisy_batch, masked_indices, p_mask = (
            diffusion_model_instance._forward_process(input_ids, eps=0.1)
        )

        # All outputs are per-token tensors shaped like the input.
        assert noisy_batch.shape == input_ids.shape
        assert masked_indices.shape == input_ids.shape
        assert p_mask.shape == input_ids.shape

        # Special tokens (bos=1, eos=2, pad=0) must never be masked.
        special_token_positions = (
            (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
        )
        assert not masked_indices[special_token_positions].any()

        # Every masked position carries the mask token. When nothing was
        # masked the selection is empty and ``.all()`` is trivially true, so
        # no separate guard is needed.
        mask_token_id = diffusion_model_instance.config.mask_token_id
        assert (noisy_batch[masked_indices] == mask_token_id).all()

    def test_forward_process_with_labels(self, diffusion_model_instance):
        """Forward process with SFT labels only masks answer tokens."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

        noisy_batch, masked_indices, p_mask = (
            diffusion_model_instance._forward_process(
                input_ids, labels=labels, eps=0.1
            )
        )

        assert noisy_batch.shape == input_ids.shape
        assert masked_indices.shape == input_ids.shape
        assert p_mask.shape == input_ids.shape

        # Positions labelled -100 are prompt (non-answer) tokens and must
        # never be masked.
        non_answer_mask = labels == -100
        assert not masked_indices[non_answer_mask].any()

    def test_forward_process_with_attention_mask(self, diffusion_model_instance):
        """Forward process leaves padded positions unmasked with zero p_mask."""
        input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)

        _, masked_indices, p_mask = diffusion_model_instance._forward_process(
            input_ids, attention_mask=attention_mask, eps=0.1
        )

        padding_positions = attention_mask == 0
        assert not masked_indices[padding_positions].any()
        assert (p_mask[padding_positions] == 0).all()

    def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
        """Without sample packing the mask is all-to-all."""
        input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)

        mask = diffusion_model_instance._create_bidirectional_attention_mask(
            input_ids
        )

        # (batch, 1, seq_len, seq_len) for the single 4-token sequence above.
        assert mask.shape == (1, 1, 4, 4)
        assert mask.all()

    def test_bidirectional_attention_mask_with_packing(
        self, diffusion_model_instance
    ):
        """With sample packing, attention stays inside each packed sample."""
        diffusion_model_instance.config.sample_packing = True
        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
        # Sample IDs: positions 0-2 belong to sample 1, positions 3-5 to sample 2.
        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)

        mask = diffusion_model_instance._create_bidirectional_attention_mask(
            input_ids, attention_mask
        )

        # Tokens of the first sample attend to each other.
        assert mask[0, 0, 0, 1].item()
        assert mask[0, 0, 1, 2].item()
        # No attention across the sample boundary.
        assert not mask[0, 0, 0, 3].item()
        assert not mask[0, 0, 2, 4].item()
        # Tokens of the second sample attend to each other.
        assert mask[0, 0, 3, 4].item()

    def test_compute_loss_basic(self, diffusion_model_instance):
        """Loss over masked middle tokens is a tensor that tracks gradients."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)

        # Random logits shaped (batch, seq_len, vocab) to match input_ids.
        logits = torch.randn(*input_ids.shape, self.VOCAB_SIZE, requires_grad=True)

        # Mask the two tokens after bos.
        masked_indices = torch.tensor(
            [[False, True, True, False, False]], dtype=torch.bool
        )
        p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)

        loss = diffusion_model_instance._compute_diffusion_loss(
            input_ids=input_ids,
            logits=logits,
            masked_indices=masked_indices,
            p_mask=p_mask,
        )

        assert isinstance(loss, torch.Tensor)
        assert loss.requires_grad

    def test_compute_loss_with_labels(self, diffusion_model_instance):
        """Loss with SFT labels is a tensor that tracks gradients."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

        # Random logits shaped (batch, seq_len, vocab) to match input_ids.
        logits = torch.randn(*input_ids.shape, self.VOCAB_SIZE, requires_grad=True)

        # Masked positions are restricted to answer tokens (labels != -100).
        masked_indices = torch.tensor(
            [[False, False, True, True, False]], dtype=torch.bool
        )
        p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)

        loss = diffusion_model_instance._compute_diffusion_loss(
            input_ids=input_ids,
            labels=labels,
            logits=logits,
            masked_indices=masked_indices,
            p_mask=p_mask,
        )

        assert isinstance(loss, torch.Tensor)
        assert loss.requires_grad

    def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
        """Loss is exactly zero, yet grad-enabled, when nothing is masked."""
        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)

        # These logits carry no grad, so the ``requires_grad`` assertion below
        # depends on the implementation returning a grad-enabled zero loss.
        logits = torch.randn(*input_ids.shape, self.VOCAB_SIZE)

        # No tokens masked.
        masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
        p_mask = torch.full(input_ids.shape, 0.1, dtype=torch.float)

        loss = diffusion_model_instance._compute_diffusion_loss(
            input_ids=input_ids,
            logits=logits,
            masked_indices=masked_indices,
            p_mask=p_mask,
        )

        assert loss.item() == 0.0
        assert loss.requires_grad

    def test_cache_special_token_ids(self, diffusion_model_instance):
        """The fixture model ends up with the pad, bos and eos ids cached."""
        expected_tokens = {0, 1, 2}  # pad, bos, eos
        assert diffusion_model_instance._special_token_ids == expected_tokens

    def test_cache_special_token_ids_no_tokenizer(self):
        """Caching with no tokenizer yields an empty id set."""
        # ``__new__`` allocates the instance without running ``__init__``; the
        # patch is kept as a guard so the real Llama constructor cannot run.
        with patch(
            "transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__"
        ):
            model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
            model._cache_special_token_ids(None)
            assert model._special_token_ids == set()

    def test_forward_training_mode(self, diffusion_model_instance):
        """In training mode ``forward`` calls the parent and produces a loss."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
        diffusion_model_instance.training = True

        with self._patch_parent_forward(diffusion_model_instance) as mock_forward:
            mock_output = Mock()
            # NOTE(review): 32000 differs from the vocab_size of 1000 in the
            # ``diffusion_config`` fixture - confirm which is intended.
            mock_output.logits = torch.randn(*input_ids.shape, 32000)
            mock_forward.return_value = mock_output

            result = diffusion_model_instance.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        assert mock_forward.called
        # ``hasattr`` is always true on a Mock (attributes are auto-created),
        # so check that a real tensor loss was produced instead.
        assert isinstance(result.loss, torch.Tensor)

    def test_forward_inference_mode(self, diffusion_model_instance):
        """In inference mode ``forward`` returns the parent's output untouched."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        diffusion_model_instance.training = False

        with self._patch_parent_forward(diffusion_model_instance) as mock_forward:
            mock_output = Mock()
            mock_forward.return_value = mock_output

            result = diffusion_model_instance.forward(
                input_ids=input_ids,
                return_dict=True,
            )

        # The parent output must be passed straight through, i.e. be the very
        # same object rather than a re-wrapped one.
        assert mock_forward.called
        assert result is mock_output
|