diffusion alt: custom loss impl

This commit is contained in:
Dan Saunders
2025-08-18 20:50:20 +00:00
parent 63d2280999
commit 260ebe4c93
5 changed files with 578 additions and 482 deletions

View File

@@ -2,111 +2,180 @@
# pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
import torch
from axolotl.integrations.diffusion.trainer import DiffusionTrainer
from axolotl.utils.dict import DictDefault
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 0
return tokenizer
from axolotl.integrations.diffusion.args import DiffusionArgs
from axolotl.integrations.diffusion.loss import (
ForDiffusionLMLoss,
register_diffusion_loss,
)
from axolotl.integrations.diffusion.model_patch import (
_create_bidirectional_attention_mask,
_forward_process,
patch_model_for_bidirectional_attention,
)
from axolotl.integrations.diffusion.plugin import DiffusionPlugin
@pytest.fixture
def diffusion_config():
"""Create a diffusion config."""
return DictDefault(
{
"mask_token_id": 32000,
"eps": 1e-3,
"importance_weighting": False,
"sample_packing": False,
}
return DiffusionArgs(
eps=1e-3,
importance_weighting=False,
mask_token_id=32000,
generate_samples=False,
)
@pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.config = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer
trainer.store_metrics = Mock() # Mock metrics storage
return trainer
def mock_model():
"""Create a mock model."""
model = Mock()
model.config = Mock()
model.config.loss_type = "ForDiffusionLM"
model.config.diffusion_config = {
"eps": 1e-3,
"importance_weighting": False,
"mask_token_id": 32000,
}
model.training = True
return model
class TestDiffusionTrainer:
"""Test the DiffusionTrainer class."""
class TestDiffusionLoss:
"""Test the ForDiffusionLMLoss function."""
def test_forward_process_basic(self, diffusion_trainer_instance):
"""Test basic forward process without labels."""
def test_loss_with_diffusion_info(self, mock_model):
"""Test loss computation with stored diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Mock stored diffusion info
original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
masked_indices = torch.tensor(
[[False, True, True, False, False]], dtype=torch.bool
)
p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
# Mock logits
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert loss.item() >= 0
def test_loss_fallback_without_diffusion_info(self, mock_model):
"""Test fallback to causal LM loss when no diffusion info."""
batch_size, seq_len, vocab_size = 1, 5, 1000
# Remove diffusion info to trigger fallback
if hasattr(mock_model, "_diffusion_info"):
delattr(mock_model, "_diffusion_info")
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_loss_no_masked_tokens(self, mock_model):
"""Test loss when no tokens are masked."""
batch_size, seq_len, vocab_size = 1, 3, 1000
# No masked tokens
original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
mock_model._diffusion_info = {
"original_input_ids": original_input_ids,
"masked_indices": masked_indices,
"p_mask": p_mask,
}
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
loss = ForDiffusionLMLoss(
logits=logits,
labels=labels,
vocab_size=vocab_size,
config=mock_model.config,
model=mock_model,
)
assert loss.item() == 0.0
class TestModelPatch:
"""Test the model patching functionality."""
def test_forward_process_basic(self):
"""Test basic forward process."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
noisy_input_ids, masked_indices, p_mask = _forward_process(
input_ids, diffusion_config=diffusion_config
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert noisy_input_ids.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that special tokens are not masked
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied where masked
if masked_indices.any():
assert (noisy_input_ids[masked_indices] == 32000).all()
# Check that mask token is applied
mask_token_id = diffusion_trainer_instance._config.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
def test_forward_process_with_labels(self, diffusion_trainer_instance):
def test_forward_process_with_labels(self):
"""Test forward process with SFT labels."""
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
noisy_batch, masked_indices, p_mask = (
diffusion_trainer_instance._forward_process(
input_ids, labels=labels, eps=0.1
)
_, masked_indices, _ = _forward_process(
input_ids, labels=labels, diffusion_config=diffusion_config
)
# Check shapes
assert noisy_batch.shape == input_ids.shape
assert masked_indices.shape == input_ids.shape
assert p_mask.shape == input_ids.shape
# Check that only answer tokens can be masked (where labels != -100)
non_answer_mask = labels == -100
# No masking should occur on non-answer tokens
assert not masked_indices[non_answer_mask].any()
# p_mask should be the same for all positions (sampled timestep),
# but masking is only applied to answer tokens
assert p_mask.shape == input_ids.shape
# Verify that masked_indices respects the answer mask
assert not masked_indices[non_answer_mask].any()
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
def test_forward_process_with_attention_mask(self):
"""Test forward process with attention mask."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
input_ids, attention_mask=attention_mask, eps=0.1
_, masked_indices, p_mask = _forward_process(
input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config
)
# Check that padding tokens are not masked
@@ -114,158 +183,146 @@ class TestDiffusionTrainer:
assert not masked_indices[padding_positions].any()
assert (p_mask[padding_positions] == 0).all()
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
"""Test bidirectional attention mask without sample packing."""
def test_bidirectional_attention_mask(self):
"""Test bidirectional attention mask creation."""
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids
)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
# Should be all-to-all attention
expected_shape = (1, 1, 4, 4)
assert mask.shape == expected_shape
assert mask.all()
def test_bidirectional_attention_mask_with_packing(
self, diffusion_trainer_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance._config.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
def test_bidirectional_attention_mask_with_padding(self):
"""Test bidirectional attention mask with padding."""
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
input_ids, attention_mask
)
mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
# Check that tokens within same sample can attend to each other
# but not across samples
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
assert mask[0, 0, 1, 2].item()
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
assert not mask[0, 0, 2, 4].item()
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
# Padding positions should not attend or be attended to
assert not mask[0, 0, 3, :].any() # Padding can't attend to anything
assert not mask[0, 0, :, 3].any() # Nothing can attend to padding
def test_compute_loss_basic(self, diffusion_trainer_instance):
"""Test basic loss computation."""
# Mock model that returns logits
def test_patch_model_for_bidirectional_attention(self):
"""Test that model patching works."""
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.config = Mock()
mock_model.config.loss_type = "ForDiffusionLM"
mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
mock_model.training = True
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
original_forward = Mock()
mock_model.forward = original_forward
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
)
# Patch the model
patch_model_for_bidirectional_attention(mock_model)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
assert outputs == mock_outputs
# Check that forward method was replaced
assert mock_model.forward != original_forward
# Check that metrics were stored
diffusion_trainer_instance.store_metrics.assert_called_once()
def test_compute_loss_with_labels(self, diffusion_trainer_instance):
"""Test loss computation with SFT labels."""
# Mock model
class TestDiffusionPlugin:
"""Test the DiffusionPlugin."""
def test_plugin_registers_loss_function(self):
"""Test that plugin registers diffusion loss function."""
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
plugin = DiffusionPlugin()
mock_register.assert_called_once()
def test_post_model_load_configuration(self):
"""Test that post_model_load configures model correctly."""
plugin = DiffusionPlugin()
# Mock model and config
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 5
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
mock_model.config = Mock()
mock_cfg = Mock()
mock_cfg.eps = 1e-3
mock_cfg.importance_weighting = True
mock_cfg.mask_token_id = 32000
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
result = plugin.post_model_load(mock_cfg, mock_model)
# Check model configuration
assert mock_model.config.loss_type == "ForDiffusionLM"
assert mock_model.config.diffusion_config is not None
assert mock_model.config.diffusion_config['eps'] == 1e-3
# Check model was patched
mock_patch.assert_called_once_with(mock_model)
# Should return the model
assert result == mock_model
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
def test_post_trainer_create_stores_config(self, diffusion_config):
"""Test that post_trainer_create stores config on trainer."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Set config attributes
for attr, value in diffusion_config.model_dump().items():
setattr(mock_cfg, attr, value)
plugin.post_trainer_create(mock_cfg, mock_trainer)
# Check that diffusion config was stored on trainer
assert hasattr(mock_trainer, 'diffusion_config')
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids, labels=labels
)
def test_add_callbacks_post_trainer_with_generation_enabled(self):
"""Test callback addition when generation is enabled."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation enabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return one callback
assert len(callbacks) == 1
mock_callback_class.assert_called_once_with(mock_trainer)
# Check that loss is computed
assert isinstance(loss, torch.Tensor)
assert loss.requires_grad
def test_add_callbacks_post_trainer_with_generation_disabled(self):
"""Test callback addition when generation is disabled."""
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation disabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return no callbacks
assert len(callbacks) == 0
# Check that SFT metrics were added
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
assert "answer_ratio" in call_args
assert "avg_answer_length" in call_args
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
"""Test loss computation when no tokens are masked."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
vocab_size = 1000
seq_len = 3
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
mock_model.return_value = mock_outputs
mock_model.training = True
class TestLossRegistration:
"""Test loss function registration."""
# Only special tokens (which won't be masked)
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
def test_register_diffusion_loss(self):
"""Test that loss function can be registered."""
with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
result = register_diffusion_loss()
assert result is True
assert "ForDiffusionLM" in mock_mapping
assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
mock_model, input_ids
)
# Loss should be zero when no tokens are masked
assert loss.item() == 0.0
assert loss.requires_grad
def test_cache_special_token_ids(self, diffusion_trainer_instance):
"""Test caching of special token IDs."""
# Should cache BOS, EOS, PAD tokens
expected_tokens = {0, 1, 2} # pad, bos, eos
assert diffusion_trainer_instance._special_token_ids == expected_tokens
def test_cache_special_token_ids_no_tokenizer(self):
"""Test caching when no tokenizer is available."""
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.processing_class = None
trainer._cache_special_token_ids()
assert trainer._special_token_ids == set()
def test_main_compute_loss_interface(self, diffusion_trainer_instance):
"""Test the main compute_loss interface."""
# Mock model
mock_model = Mock()
mock_outputs = Mock()
mock_outputs.logits = torch.randn(1, 5, 1000)
mock_model.return_value = mock_outputs
mock_model.training = True
inputs = {
"input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
"labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
}
# Test without return_outputs
loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
assert isinstance(loss, torch.Tensor)
# Test with return_outputs
loss, outputs = diffusion_trainer_instance.compute_loss(
mock_model, inputs, return_outputs=True
)
assert isinstance(loss, torch.Tensor)
assert outputs == mock_outputs
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
"""Test that missing input_ids raises ValueError."""
mock_model = Mock()
inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
with pytest.raises(ValueError, match="input_ids is required"):
diffusion_trainer_instance.compute_loss(mock_model, inputs)
def test_register_diffusion_loss_import_error(self):
"""Test fallback when LOSS_MAPPING import fails."""
# Patch the import to raise ImportError
with patch(
"builtins.__import__",
side_effect=ImportError("transformers.loss.loss_utils not found"),
):
result = register_diffusion_loss()
assert result is False