diffusion custom models approach
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
"""Tests for diffusion trainer integration."""
|
||||
"""Tests for diffusion model integration."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,protected-access
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.diffusion.trainer import DiffusionTrainer
|
||||
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig
|
||||
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -24,37 +25,44 @@ def mock_tokenizer():
|
||||
@pytest.fixture
|
||||
def diffusion_config():
|
||||
"""Create a diffusion config."""
|
||||
return DictDefault(
|
||||
{
|
||||
"mask_token_id": 32000,
|
||||
"eps": 1e-3,
|
||||
"importance_weighting": False,
|
||||
"sample_packing": False,
|
||||
}
|
||||
return LlamaForDiffusionConfig(
|
||||
mask_token_id=32000,
|
||||
eps=1e-3,
|
||||
importance_weighting=False,
|
||||
sample_packing=False,
|
||||
# Basic Llama config fields, kept small for testing
|
||||
vocab_size=1000,
|
||||
hidden_size=256,
|
||||
intermediate_size=512,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
|
||||
"""Create a diffusion trainer instance for testing methods directly."""
|
||||
# Create a minimal trainer instance just for testing methods
|
||||
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
|
||||
trainer.config = diffusion_config
|
||||
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
|
||||
trainer.processing_class = mock_tokenizer
|
||||
trainer.store_metrics = Mock() # Mock metrics storage
|
||||
return trainer
|
||||
def diffusion_model_instance(mock_tokenizer, diffusion_config):
|
||||
"""Create a diffusion model instance for testing methods directly."""
|
||||
# Create a minimal model instance for testing
|
||||
model = object.__new__(LlamaForDiffusionLM)
|
||||
model.config = diffusion_config
|
||||
model._special_token_ids = {0, 1, 2} # pad, bos, eos
|
||||
model.training = True
|
||||
|
||||
# Set tokenizer
|
||||
model.set_tokenizer(mock_tokenizer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class TestDiffusionTrainer:
|
||||
"""Test the DiffusionTrainer class."""
|
||||
class TestDiffusionModel:
|
||||
"""Test the LlamaForDiffusionLM class."""
|
||||
|
||||
def test_forward_process_basic(self, diffusion_trainer_instance):
|
||||
def test_forward_process_basic(self, diffusion_model_instance):
|
||||
"""Test basic forward process without labels."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
noisy_batch, masked_indices, p_mask = (
|
||||
diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
|
||||
diffusion_model_instance._forward_process(input_ids, eps=0.1)
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
@@ -67,18 +75,18 @@ class TestDiffusionTrainer:
|
||||
assert not masked_indices[special_token_positions].any()
|
||||
|
||||
# Check that mask token is applied
|
||||
mask_token_id = diffusion_trainer_instance._config.mask_token_id
|
||||
mask_token_id = diffusion_model_instance.config.mask_token_id
|
||||
masked_positions = masked_indices
|
||||
if masked_positions.any():
|
||||
assert (noisy_batch[masked_positions] == mask_token_id).all()
|
||||
|
||||
def test_forward_process_with_labels(self, diffusion_trainer_instance):
|
||||
def test_forward_process_with_labels(self, diffusion_model_instance):
|
||||
"""Test forward process with SFT labels."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
noisy_batch, masked_indices, p_mask = (
|
||||
diffusion_trainer_instance._forward_process(
|
||||
diffusion_model_instance._forward_process(
|
||||
input_ids, labels=labels, eps=0.1
|
||||
)
|
||||
)
|
||||
@@ -100,12 +108,12 @@ class TestDiffusionTrainer:
|
||||
# Verify that masked_indices respects the answer mask
|
||||
assert not masked_indices[non_answer_mask].any()
|
||||
|
||||
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
|
||||
def test_forward_process_with_attention_mask(self, diffusion_model_instance):
|
||||
"""Test forward process with attention mask."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
|
||||
|
||||
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
|
||||
_, masked_indices, p_mask = diffusion_model_instance._forward_process(
|
||||
input_ids, attention_mask=attention_mask, eps=0.1
|
||||
)
|
||||
|
||||
@@ -114,11 +122,11 @@ class TestDiffusionTrainer:
|
||||
assert not masked_indices[padding_positions].any()
|
||||
assert (p_mask[padding_positions] == 0).all()
|
||||
|
||||
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
|
||||
def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
|
||||
"""Test bidirectional attention mask without sample packing."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
|
||||
|
||||
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
|
||||
mask = diffusion_model_instance._create_bidirectional_attention_mask(
|
||||
input_ids
|
||||
)
|
||||
|
||||
@@ -128,15 +136,15 @@ class TestDiffusionTrainer:
|
||||
assert mask.all()
|
||||
|
||||
def test_bidirectional_attention_mask_with_packing(
|
||||
self, diffusion_trainer_instance
|
||||
self, diffusion_model_instance
|
||||
):
|
||||
"""Test bidirectional attention mask with sample packing."""
|
||||
diffusion_trainer_instance._config.sample_packing = True
|
||||
diffusion_model_instance.config.sample_packing = True
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
|
||||
# Sample IDs: first sample (1), second sample (2)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
|
||||
|
||||
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
|
||||
mask = diffusion_model_instance._create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask
|
||||
)
|
||||
|
||||
@@ -148,124 +156,135 @@ class TestDiffusionTrainer:
|
||||
assert not mask[0, 0, 2, 4].item()
|
||||
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
|
||||
|
||||
def test_compute_loss_basic(self, diffusion_trainer_instance):
|
||||
def test_compute_loss_basic(self, diffusion_model_instance):
|
||||
"""Test basic loss computation."""
|
||||
# Mock model that returns logits
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
# Create mock data for loss computation
|
||||
vocab_size = 1000
|
||||
seq_len = 5
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
|
||||
# Create a simple masked indices tensor (mask middle tokens)
|
||||
masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
|
||||
p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
|
||||
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids
|
||||
loss = diffusion_model_instance._compute_diffusion_loss(
|
||||
input_ids=input_ids,
|
||||
logits=logits,
|
||||
masked_indices=masked_indices,
|
||||
p_mask=p_mask,
|
||||
)
|
||||
|
||||
# Check that loss is computed
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.requires_grad
|
||||
assert outputs == mock_outputs
|
||||
|
||||
# Check that metrics were stored
|
||||
diffusion_trainer_instance.store_metrics.assert_called_once()
|
||||
|
||||
def test_compute_loss_with_labels(self, diffusion_trainer_instance):
|
||||
def test_compute_loss_with_labels(self, diffusion_model_instance):
|
||||
"""Test loss computation with SFT labels."""
|
||||
# Mock model
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
vocab_size = 1000
|
||||
seq_len = 5
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
# Create mock data for loss computation
|
||||
vocab_size = 1000
|
||||
seq_len = 5
|
||||
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
|
||||
# Create masked indices that only covers answer tokens
|
||||
masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
|
||||
p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
|
||||
|
||||
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids, labels=labels
|
||||
loss = diffusion_model_instance._compute_diffusion_loss(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
logits=logits,
|
||||
masked_indices=masked_indices,
|
||||
p_mask=p_mask,
|
||||
)
|
||||
|
||||
# Check that loss is computed
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.requires_grad
|
||||
|
||||
# Check that SFT metrics were added
|
||||
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
|
||||
assert "answer_ratio" in call_args
|
||||
assert "avg_answer_length" in call_args
|
||||
|
||||
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
|
||||
def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
|
||||
"""Test loss computation when no tokens are masked."""
|
||||
# Mock model
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
|
||||
|
||||
# Create mock data for loss computation
|
||||
vocab_size = 1000
|
||||
seq_len = 3
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
logits = torch.randn(1, seq_len, vocab_size)
|
||||
|
||||
# No tokens masked
|
||||
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
|
||||
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
|
||||
|
||||
# Only special tokens (which won't be masked)
|
||||
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
|
||||
|
||||
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids
|
||||
loss = diffusion_model_instance._compute_diffusion_loss(
|
||||
input_ids=input_ids,
|
||||
logits=logits,
|
||||
masked_indices=masked_indices,
|
||||
p_mask=p_mask,
|
||||
)
|
||||
|
||||
# Loss should be zero when no tokens are masked
|
||||
assert loss.item() == 0.0
|
||||
assert loss.requires_grad
|
||||
|
||||
def test_cache_special_token_ids(self, diffusion_trainer_instance):
|
||||
def test_cache_special_token_ids(self, diffusion_model_instance):
|
||||
"""Test caching of special token IDs."""
|
||||
# Should cache BOS, EOS, PAD tokens
|
||||
expected_tokens = {0, 1, 2} # pad, bos, eos
|
||||
assert diffusion_trainer_instance._special_token_ids == expected_tokens
|
||||
assert diffusion_model_instance._special_token_ids == expected_tokens
|
||||
|
||||
def test_cache_special_token_ids_no_tokenizer(self):
|
||||
"""Test caching when no tokenizer is available."""
|
||||
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
|
||||
trainer.processing_class = None
|
||||
trainer._cache_special_token_ids()
|
||||
# Patch the parent __init__ so no real Llama model is built (note: __new__ below does not invoke __init__ anyway)
|
||||
with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
|
||||
model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
|
||||
model._cache_special_token_ids(None)
|
||||
assert model._special_token_ids == set()
|
||||
|
||||
assert trainer._special_token_ids == set()
|
||||
def test_forward_training_mode(self, diffusion_model_instance):
|
||||
"""Test forward pass in training mode."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
|
||||
|
||||
# Mock the parent forward method
|
||||
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
|
||||
mock_output = Mock()
|
||||
mock_output.logits = torch.randn(1, 5, 32000)
|
||||
mock_forward.return_value = mock_output
|
||||
|
||||
# Set training mode
|
||||
diffusion_model_instance.training = True
|
||||
|
||||
result = diffusion_model_instance.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Should call parent forward and compute loss
|
||||
assert mock_forward.called
|
||||
assert hasattr(result, 'loss')
|
||||
|
||||
def test_main_compute_loss_interface(self, diffusion_trainer_instance):
|
||||
"""Test the main compute_loss interface."""
|
||||
# Mock model
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
mock_outputs.logits = torch.randn(1, 5, 1000)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
|
||||
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
|
||||
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
|
||||
}
|
||||
|
||||
# Test without return_outputs
|
||||
loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test with return_outputs
|
||||
loss, outputs = diffusion_trainer_instance.compute_loss(
|
||||
mock_model, inputs, return_outputs=True
|
||||
)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert outputs == mock_outputs
|
||||
|
||||
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
|
||||
"""Test that missing input_ids raises ValueError."""
|
||||
mock_model = Mock()
|
||||
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
|
||||
|
||||
with pytest.raises(ValueError, match="input_ids is required"):
|
||||
diffusion_trainer_instance.compute_loss(mock_model, inputs)
|
||||
def test_forward_inference_mode(self, diffusion_model_instance):
|
||||
"""Test forward pass in inference mode."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
# Mock the parent forward method
|
||||
with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
|
||||
mock_output = Mock()
|
||||
mock_forward.return_value = mock_output
|
||||
|
||||
# Set inference mode
|
||||
diffusion_model_instance.training = False
|
||||
|
||||
result = diffusion_model_instance.forward(
|
||||
input_ids=input_ids,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Should just call parent forward without diffusion processing
|
||||
assert mock_forward.called
|
||||
assert result == mock_output
|
||||
|
||||
Reference in New Issue
Block a user