diffusion alt: custom loss impl

This commit is contained in:
Dan Saunders
2025-08-18 20:50:34 +00:00
parent 260ebe4c93
commit 64f349b7bb
2 changed files with 31 additions and 23 deletions

View File

@@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin):
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load( def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
self, cfg: DictDefault, model: PreTrainedModel | PeftModel
):
"""Configure model for diffusion training after loading.""" """Configure model for diffusion training after loading."""
self.cfg = cfg self.cfg = cfg
@@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin):
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer): def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
"""Add diffusion generation callback if enabled.""" """Add diffusion generation callback if enabled."""
if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples: if (
hasattr(trainer, "diffusion_config")
and trainer.diffusion_config.generate_samples
):
generation_callback = DiffusionGenerationCallback(trainer) generation_callback = DiffusionGenerationCallback(trainer)
LOG.info("Added diffusion generation callback") LOG.info("Added diffusion generation callback")
return [generation_callback] return [generation_callback]

View File

@@ -229,7 +229,10 @@ class TestDiffusionPlugin:
def test_plugin_registers_loss_function(self): def test_plugin_registers_loss_function(self):
"""Test that plugin registers diffusion loss function.""" """Test that plugin registers diffusion loss function."""
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register: with patch(
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
return_value=True,
) as mock_register:
plugin = DiffusionPlugin() plugin = DiffusionPlugin()
mock_register.assert_called_once() mock_register.assert_called_once()
@@ -245,13 +248,15 @@ class TestDiffusionPlugin:
mock_cfg.importance_weighting = True mock_cfg.importance_weighting = True
mock_cfg.mask_token_id = 32000 mock_cfg.mask_token_id = 32000
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch: with patch(
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
) as mock_patch:
result = plugin.post_model_load(mock_cfg, mock_model) result = plugin.post_model_load(mock_cfg, mock_model)
# Check model configuration # Check model configuration
assert mock_model.config.loss_type == "ForDiffusionLM" assert mock_model.config.loss_type == "ForDiffusionLM"
assert mock_model.config.diffusion_config is not None assert mock_model.config.diffusion_config is not None
assert mock_model.config.diffusion_config['eps'] == 1e-3 assert mock_model.config.diffusion_config["eps"] == 1e-3
# Check model was patched # Check model was patched
mock_patch.assert_called_once_with(mock_model) mock_patch.assert_called_once_with(mock_model)
@@ -272,7 +277,7 @@ class TestDiffusionPlugin:
plugin.post_trainer_create(mock_cfg, mock_trainer) plugin.post_trainer_create(mock_cfg, mock_trainer)
# Check that diffusion config was stored on trainer # Check that diffusion config was stored on trainer
assert hasattr(mock_trainer, 'diffusion_config') assert hasattr(mock_trainer, "diffusion_config")
assert mock_trainer.diffusion_config.eps == diffusion_config.eps assert mock_trainer.diffusion_config.eps == diffusion_config.eps
def test_add_callbacks_post_trainer_with_generation_enabled(self): def test_add_callbacks_post_trainer_with_generation_enabled(self):
@@ -284,7 +289,9 @@ class TestDiffusionPlugin:
# Mock trainer with diffusion config that has generation enabled # Mock trainer with diffusion config that has generation enabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True) mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class: with patch(
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
) as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer) callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return one callback # Should return one callback