diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py index d22b277ea..dae192a33 100644 --- a/src/axolotl/integrations/diffusion/plugin.py +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin): """Returns the pydantic model for LLaDA plugin arguments.""" return "axolotl.integrations.diffusion.DiffusionArgs" - def post_model_load( - self, cfg: DictDefault, model: PreTrainedModel | PeftModel - ): + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Configure model for diffusion training after loading.""" self.cfg = cfg @@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin): def add_callbacks_post_trainer(self, cfg: DictDefault, trainer): """Add diffusion generation callback if enabled.""" - if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples: + if ( + hasattr(trainer, "diffusion_config") + and trainer.diffusion_config.generate_samples + ): generation_callback = DiffusionGenerationCallback(trainer) LOG.info("Added diffusion generation callback") return [generation_callback] diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py index 37709e349..4436caac6 100644 --- a/tests/integrations/test_diffusion.py +++ b/tests/integrations/test_diffusion.py @@ -229,14 +229,17 @@ class TestDiffusionPlugin: def test_plugin_registers_loss_function(self): """Test that plugin registers diffusion loss function.""" - with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register: + with patch( + "axolotl.integrations.diffusion.plugin.register_diffusion_loss", + return_value=True, + ) as mock_register: plugin = DiffusionPlugin() mock_register.assert_called_once() def test_post_model_load_configuration(self): """Test that post_model_load configures model correctly.""" plugin = DiffusionPlugin() - + # Mock model and config mock_model = Mock() mock_model.config = Mock() @@ -244,18 +247,20 @@ class TestDiffusionPlugin: mock_cfg.eps = 1e-3 mock_cfg.importance_weighting = True mock_cfg.mask_token_id = 32000 - - with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch: + + with patch( + "axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention" + ) as mock_patch: result = plugin.post_model_load(mock_cfg, mock_model) - + # Check model configuration assert mock_model.config.loss_type == "ForDiffusionLM" assert mock_model.config.diffusion_config is not None - assert mock_model.config.diffusion_config['eps'] == 1e-3 - + assert mock_model.config.diffusion_config["eps"] == 1e-3 + # Check model was patched mock_patch.assert_called_once_with(mock_model) - + # Should return the model assert result == mock_model @@ -264,15 +269,15 @@ class TestDiffusionPlugin: plugin = DiffusionPlugin() mock_trainer = Mock() mock_cfg = Mock() - + # Set config attributes for attr, value in diffusion_config.model_dump().items(): setattr(mock_cfg, attr, value) - + plugin.post_trainer_create(mock_cfg, mock_trainer) - + # Check that diffusion config was stored on trainer - assert hasattr(mock_trainer, 'diffusion_config') + assert hasattr(mock_trainer, "diffusion_config") assert mock_trainer.diffusion_config.eps == diffusion_config.eps def test_add_callbacks_post_trainer_with_generation_enabled(self): @@ -280,13 +285,15 @@ class TestDiffusionPlugin: plugin = DiffusionPlugin() mock_trainer = Mock() mock_cfg = Mock() - + # Mock trainer with diffusion config that has generation enabled mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True) - - with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class: + + with patch( + "axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback" + ) as mock_callback_class: callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer) - + # Should return one callback assert len(callbacks) == 1 mock_callback_class.assert_called_once_with(mock_trainer) @@ -296,12 +303,12 @@ class TestDiffusionPlugin: plugin = DiffusionPlugin() mock_trainer = Mock() mock_cfg = Mock() - + # Mock trainer with diffusion config that has generation disabled mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False) - + callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer) - + # Should return no callbacks assert len(callbacks) == 0