diffusion alt: custom loss impl
@@ -2,111 +2,180 @@
 # pylint: disable=redefined-outer-name,protected-access
 
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
-from axolotl.integrations.diffusion.trainer import DiffusionTrainer
-from axolotl.utils.dict import DictDefault
-
-
-@pytest.fixture
-def mock_tokenizer():
-    """Create a mock tokenizer."""
-    tokenizer = Mock()
-    tokenizer.bos_token_id = 1
-    tokenizer.eos_token_id = 2
-    tokenizer.pad_token_id = 0
-    return tokenizer
+from axolotl.integrations.diffusion.args import DiffusionArgs
+from axolotl.integrations.diffusion.loss import (
+    ForDiffusionLMLoss,
+    register_diffusion_loss,
+)
+from axolotl.integrations.diffusion.model_patch import (
+    _create_bidirectional_attention_mask,
+    _forward_process,
+    patch_model_for_bidirectional_attention,
+)
+from axolotl.integrations.diffusion.plugin import DiffusionPlugin
 
 
 @pytest.fixture
 def diffusion_config():
     """Create a diffusion config."""
-    return DictDefault(
-        {
-            "mask_token_id": 32000,
-            "eps": 1e-3,
-            "importance_weighting": False,
-            "sample_packing": False,
-        }
+    return DiffusionArgs(
+        eps=1e-3,
+        importance_weighting=False,
+        mask_token_id=32000,
+        generate_samples=False,
     )
 
 
 @pytest.fixture
-def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
-    """Create a diffusion trainer instance for testing methods directly."""
-    # Create a minimal trainer instance just for testing methods
-    trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
-    trainer.config = diffusion_config
-    trainer._special_token_ids = {0, 1, 2}  # pad, bos, eos
-    trainer.processing_class = mock_tokenizer
-    trainer.store_metrics = Mock()  # Mock metrics storage
-    return trainer
+def mock_model():
+    """Create a mock model."""
+    model = Mock()
+    model.config = Mock()
+    model.config.loss_type = "ForDiffusionLM"
+    model.config.diffusion_config = {
+        "eps": 1e-3,
+        "importance_weighting": False,
+        "mask_token_id": 32000,
+    }
+    model.training = True
+    return model
 
 
-class TestDiffusionTrainer:
-    """Test the DiffusionTrainer class."""
+class TestDiffusionLoss:
+    """Test the ForDiffusionLMLoss function."""
 
-    def test_forward_process_basic(self, diffusion_trainer_instance):
-        """Test basic forward process without labels."""
+    def test_loss_with_diffusion_info(self, mock_model):
+        """Test loss computation with stored diffusion info."""
+        batch_size, seq_len, vocab_size = 1, 5, 1000
+
+        # Mock stored diffusion info
+        original_input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        masked_indices = torch.tensor(
+            [[False, True, True, False, False]], dtype=torch.bool
+        )
+        p_mask = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.5]], dtype=torch.float)
+
+        mock_model._diffusion_info = {
+            "original_input_ids": original_input_ids,
+            "masked_indices": masked_indices,
+            "p_mask": p_mask,
+        }
+
+        # Mock logits
+        logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
+        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
+        assert loss.item() >= 0
+
+    def test_loss_fallback_without_diffusion_info(self, mock_model):
+        """Test fallback to causal LM loss when no diffusion info."""
+        batch_size, seq_len, vocab_size = 1, 5, 1000
+
+        # Remove diffusion info to trigger fallback
+        if hasattr(mock_model, "_diffusion_info"):
+            delattr(mock_model, "_diffusion_info")
+
+        logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
+        labels = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert isinstance(loss, torch.Tensor)
+        assert loss.requires_grad
+
+    def test_loss_no_masked_tokens(self, mock_model):
+        """Test loss when no tokens are masked."""
+        batch_size, seq_len, vocab_size = 1, 3, 1000
+
+        # No masked tokens
+        original_input_ids = torch.tensor([[1, 10, 2]], dtype=torch.long)
+        masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
+        p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
+
+        mock_model._diffusion_info = {
+            "original_input_ids": original_input_ids,
+            "masked_indices": masked_indices,
+            "p_mask": p_mask,
+        }
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        labels = torch.tensor([[1, 10, 2]], dtype=torch.long)
+
+        loss = ForDiffusionLMLoss(
+            logits=logits,
+            labels=labels,
+            vocab_size=vocab_size,
+            config=mock_model.config,
+            model=mock_model,
+        )
+
+        assert loss.item() == 0.0
+
+
+class TestModelPatch:
+    """Test the model patching functionality."""
+
+    def test_forward_process_basic(self):
+        """Test basic forward process."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        noisy_batch, masked_indices, p_mask = (
-            diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
+        noisy_input_ids, masked_indices, p_mask = _forward_process(
+            input_ids, diffusion_config=diffusion_config
         )
 
         # Check shapes
-        assert noisy_batch.shape == input_ids.shape
+        assert noisy_input_ids.shape == input_ids.shape
         assert masked_indices.shape == input_ids.shape
         assert p_mask.shape == input_ids.shape
 
         # Check that special tokens are not masked
         special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
         assert not masked_indices[special_token_positions].any()
+        # Check that mask token is applied where masked
+        if masked_indices.any():
+            assert (noisy_input_ids[masked_indices] == 32000).all()
 
-        # Check that mask token is applied
-        mask_token_id = diffusion_trainer_instance._config.mask_token_id
-        masked_positions = masked_indices
-        if masked_positions.any():
-            assert (noisy_batch[masked_positions] == mask_token_id).all()
-
-    def test_forward_process_with_labels(self, diffusion_trainer_instance):
+    def test_forward_process_with_labels(self):
         """Test forward process with SFT labels."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
         labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        noisy_batch, masked_indices, p_mask = (
-            diffusion_trainer_instance._forward_process(
-                input_ids, labels=labels, eps=0.1
-            )
+        _, masked_indices, _ = _forward_process(
+            input_ids, labels=labels, diffusion_config=diffusion_config
         )
 
-        # Check shapes
-        assert noisy_batch.shape == input_ids.shape
-        assert masked_indices.shape == input_ids.shape
-        assert p_mask.shape == input_ids.shape
-
         # Check that only answer tokens can be masked (where labels != -100)
         non_answer_mask = labels == -100
-
-        # No masking should occur on non-answer tokens
-        assert not masked_indices[non_answer_mask].any()
-
-        # p_mask should be the same for all positions (sampled timestep),
-        # but masking is only applied to answer tokens
-        assert p_mask.shape == input_ids.shape
+        # Verify that masked_indices respects the answer mask
         assert not masked_indices[non_answer_mask].any()
 
-    def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
+    def test_forward_process_with_attention_mask(self):
         """Test forward process with attention mask."""
         input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
         attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
+        diffusion_config = {"eps": 0.1, "mask_token_id": 32000}
 
-        _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
-            input_ids, attention_mask=attention_mask, eps=0.1
+        _, masked_indices, p_mask = _forward_process(
+            input_ids, attention_mask=attention_mask, diffusion_config=diffusion_config
         )
 
         # Check that padding tokens are not masked
@@ -114,158 +183,146 @@ class TestDiffusionTrainer:
         assert not masked_indices[padding_positions].any()
         assert (p_mask[padding_positions] == 0).all()
 
-    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
-        """Test bidirectional attention mask without sample packing."""
+    def test_bidirectional_attention_mask(self):
+        """Test bidirectional attention mask creation."""
         input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)
 
-        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
-            input_ids
-        )
+        mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
 
         # Should be all-to-all attention
         expected_shape = (1, 1, 4, 4)
         assert mask.shape == expected_shape
         assert mask.all()
 
-    def test_bidirectional_attention_mask_with_packing(
-        self, diffusion_trainer_instance
-    ):
-        """Test bidirectional attention mask with sample packing."""
-        diffusion_trainer_instance._config.sample_packing = True
-        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
-        # Sample IDs: first sample (1), second sample (2)
-        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
+    def test_bidirectional_attention_mask_with_padding(self):
+        """Test bidirectional attention mask with padding."""
+        input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
 
-        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
-            input_ids, attention_mask
-        )
+        mask = _create_bidirectional_attention_mask(input_ids, attention_mask)
 
-        # Check that tokens within same sample can attend to each other
-        # but not across samples
-        assert mask[0, 0, 0, 1].item()  # First sample tokens can attend to each other
-        assert mask[0, 0, 1, 2].item()
-        assert not mask[0, 0, 0, 3].item()  # Can't attend across samples
-        assert not mask[0, 0, 2, 4].item()
-        assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other
+        # Padding positions should not attend or be attended to
+        assert not mask[0, 0, 3, :].any()  # Padding can't attend to anything
+        assert not mask[0, 0, :, 3].any()  # Nothing can attend to padding
 
-    def test_compute_loss_basic(self, diffusion_trainer_instance):
-        """Test basic loss computation."""
-        # Mock model that returns logits
+    def test_patch_model_for_bidirectional_attention(self):
+        """Test that model patching works."""
         mock_model = Mock()
         mock_outputs = Mock()
         vocab_size = 1000
         seq_len = 5
         mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
         mock_model.return_value = mock_outputs
+        mock_model.config = Mock()
+        mock_model.config.loss_type = "ForDiffusionLM"
+        mock_model.config.diffusion_config = {"eps": 1e-3, "mask_token_id": 32000}
         mock_model.training = True
 
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+        original_forward = Mock()
+        mock_model.forward = original_forward
 
-        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
-        )
+        # Patch the model
+        patch_model_for_bidirectional_attention(mock_model)
 
-        # Check that loss is computed
-        assert isinstance(loss, torch.Tensor)
-        assert loss.requires_grad
-        assert outputs == mock_outputs
+        # Check that forward method was replaced
+        assert mock_model.forward != original_forward
 
-        # Check that metrics were stored
-        diffusion_trainer_instance.store_metrics.assert_called_once()
 
-    def test_compute_loss_with_labels(self, diffusion_trainer_instance):
-        """Test loss computation with SFT labels."""
-        # Mock model
+class TestDiffusionPlugin:
+    """Test the DiffusionPlugin."""
+
+    def test_plugin_registers_loss_function(self):
+        """Test that plugin registers diffusion loss function."""
+        with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
+            plugin = DiffusionPlugin()
+            mock_register.assert_called_once()
+
+    def test_post_model_load_configuration(self):
+        """Test that post_model_load configures model correctly."""
+        plugin = DiffusionPlugin()
+
+        # Mock model and config
         mock_model = Mock()
         mock_outputs = Mock()
         vocab_size = 1000
         seq_len = 5
         mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
         mock_model.return_value = mock_outputs
         mock_model.training = True
+        mock_model.config = Mock()
+        mock_cfg = Mock()
+        mock_cfg.eps = 1e-3
+        mock_cfg.importance_weighting = True
+        mock_cfg.mask_token_id = 32000
 
+        with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
+            result = plugin.post_model_load(mock_cfg, mock_model)
+
+            # Check model configuration
+            assert mock_model.config.loss_type == "ForDiffusionLM"
+            assert mock_model.config.diffusion_config is not None
+            assert mock_model.config.diffusion_config['eps'] == 1e-3
+
+            # Check model was patched
+            mock_patch.assert_called_once_with(mock_model)
+
+            # Should return the model
+            assert result == mock_model
 
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
-        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+    def test_post_trainer_create_stores_config(self, diffusion_config):
+        """Test that post_trainer_create stores config on trainer."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Set config attributes
+        for attr, value in diffusion_config.model_dump().items():
+            setattr(mock_cfg, attr, value)
+
+        plugin.post_trainer_create(mock_cfg, mock_trainer)
+
+        # Check that diffusion config was stored on trainer
+        assert hasattr(mock_trainer, 'diffusion_config')
+        assert mock_trainer.diffusion_config.eps == diffusion_config.eps
 
-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids, labels=labels
-        )
+    def test_add_callbacks_post_trainer_with_generation_enabled(self):
+        """Test callback addition when generation is enabled."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Mock trainer with diffusion config that has generation enabled
+        mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
+
+        with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
+            callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
+
+            # Should return one callback
+            assert len(callbacks) == 1
+            mock_callback_class.assert_called_once_with(mock_trainer)
 
-        # Check that loss is computed
-        assert isinstance(loss, torch.Tensor)
-        assert loss.requires_grad
+    def test_add_callbacks_post_trainer_with_generation_disabled(self):
+        """Test callback addition when generation is disabled."""
+        plugin = DiffusionPlugin()
+        mock_trainer = Mock()
+        mock_cfg = Mock()
+
+        # Mock trainer with diffusion config that has generation disabled
+        mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
+
+        callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
+
+        # Should return no callbacks
+        assert len(callbacks) == 0
 
-        # Check that SFT metrics were added
-        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
-        assert "answer_ratio" in call_args
-        assert "avg_answer_length" in call_args
-
-    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
-        """Test loss computation when no tokens are masked."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        vocab_size = 1000
-        seq_len = 3
-        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
+
+class TestLossRegistration:
+    """Test loss function registration."""
 
-        # Only special tokens (which won't be masked)
-        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
+    def test_register_diffusion_loss(self):
+        """Test that loss function can be registered."""
+        with patch("transformers.loss.loss_utils.LOSS_MAPPING", {}) as mock_mapping:
+            result = register_diffusion_loss()
+            assert result is True
+            assert "ForDiffusionLM" in mock_mapping
+            assert mock_mapping["ForDiffusionLM"] == ForDiffusionLMLoss
 
-        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
-            mock_model, input_ids
-        )
-
-        # Loss should be zero when no tokens are masked
-        assert loss.item() == 0.0
-        assert loss.requires_grad
-
-    def test_cache_special_token_ids(self, diffusion_trainer_instance):
-        """Test caching of special token IDs."""
-        # Should cache BOS, EOS, PAD tokens
-        expected_tokens = {0, 1, 2}  # pad, bos, eos
-        assert diffusion_trainer_instance._special_token_ids == expected_tokens
-
-    def test_cache_special_token_ids_no_tokenizer(self):
-        """Test caching when no tokenizer is available."""
-        trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
-        trainer.processing_class = None
-        trainer._cache_special_token_ids()
-
-        assert trainer._special_token_ids == set()
-
-    def test_main_compute_loss_interface(self, diffusion_trainer_instance):
-        """Test the main compute_loss interface."""
-        # Mock model
-        mock_model = Mock()
-        mock_outputs = Mock()
-        mock_outputs.logits = torch.randn(1, 5, 1000)
-        mock_model.return_value = mock_outputs
-        mock_model.training = True
-
-        inputs = {
-            "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
-            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
-            "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
-        }
-
-        # Test without return_outputs
-        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
-        assert isinstance(loss, torch.Tensor)
-
-        # Test with return_outputs
-        loss, outputs = diffusion_trainer_instance.compute_loss(
-            mock_model, inputs, return_outputs=True
-        )
-        assert isinstance(loss, torch.Tensor)
-        assert outputs == mock_outputs
-
-    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
-        """Test that missing input_ids raises ValueError."""
-        mock_model = Mock()
-        inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
-
-        with pytest.raises(ValueError, match="input_ids is required"):
-            diffusion_trainer_instance.compute_loss(mock_model, inputs)
+    def test_register_diffusion_loss_import_error(self):
+        """Test fallback when LOSS_MAPPING import fails."""
+        # Patch the import to raise ImportError
+        with patch(
+            "builtins.__import__",
+            side_effect=ImportError("transformers.loss.loss_utils not found"),
+        ):
+            result = register_diffusion_loss()
+            assert result is False
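
Note: the TestLossRegistration cases pin down the full contract of register_diffusion_loss: return False when transformers' loss registry cannot be imported, otherwise map the "ForDiffusionLM" key to ForDiffusionLMLoss and return True. A minimal sketch consistent with those tests (not necessarily the body shipped in this commit; it assumes ForDiffusionLMLoss is defined in the same module):

    def register_diffusion_loss() -> bool:
        """Hook ForDiffusionLMLoss into transformers' pluggable loss registry."""
        try:
            # transformers dispatches on config.loss_type via this mapping.
            from transformers.loss.loss_utils import LOSS_MAPPING
        except ImportError:
            # Older transformers versions without the loss registry.
            return False
        LOSS_MAPPING["ForDiffusionLM"] = ForDiffusionLMLoss
        return True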
|
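Similarly, TestDiffusionLoss fixes the loss function's interface: keyword arguments (logits, labels, vocab_size, config, model), diffusion state read from model._diffusion_info, a zero loss when nothing is masked, and a causal-LM fallback when the attribute is absent. The following sketch satisfies that contract; the division by p_mask is an assumption borrowed from LLaDA-style masked-diffusion training, not something these tests assert:

    import torch
    import torch.nn.functional as F

    def diffusion_lm_loss_sketch(logits, labels, vocab_size, config, model, **kwargs):
        """Hypothetical stand-in for ForDiffusionLMLoss (illustration only)."""
        info = getattr(model, "_diffusion_info", None)
        if info is None:
            # Fallback: ordinary causal-LM cross-entropy with the one-token shift.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
        masked = info["masked_indices"]
        if not masked.any():
            # No masked positions this step -> zero loss (test_loss_no_masked_tokens).
            return logits.sum() * 0.0
        # Cross-entropy only at masked positions, predicting the original tokens.
        per_token = F.cross_entropy(
            logits[masked], info["original_input_ids"][masked], reduction="none"
        )
        # Assumed LLaDA-style reweighting by the masking probability p_mask.
        return (per_token / info["p_mask"][masked]).mean()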
||||
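The bidirectional-mask tests in TestModelPatch likewise imply a plain outer product over the sequence axis: position i may attend to position j exactly when both carry a nonzero attention_mask. A sketch with a hypothetical helper name, matching the (batch, 1, seq, seq) shape and padding behavior asserted above (input_ids is kept only to mirror the tested signature):

    import torch

    def bidirectional_attention_mask_sketch(
        input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Real tokens attend all-to-all; padding neither attends nor is attended."""
        valid = attention_mask.bool()  # (batch, seq)
        # (batch, 1, seq, 1) & (batch, 1, 1, seq) -> (batch, 1, seq, seq)
        return valid[:, None, :, None] & valid[:, None, None, :]

    mask = bidirectional_attention_mask_sketch(
        torch.tensor([[1, 10, 20, 0]]), torch.tensor([[1, 1, 1, 0]])
    )
    assert mask.shape == (1, 1, 4, 4)
    assert not mask[0, 0, :, 3].any() and not mask[0, 0, 3, :].any()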