diffusion alt: custom loss impl

This commit is contained in:
Dan Saunders
2025-08-18 20:50:34 +00:00
parent 260ebe4c93
commit 64f349b7bb
2 changed files with 31 additions and 23 deletions

View File

@@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin):
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(
self, cfg: DictDefault, model: PreTrainedModel | PeftModel
):
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model for diffusion training after loading."""
self.cfg = cfg
@@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin):
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
"""Add diffusion generation callback if enabled."""
if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples:
if (
hasattr(trainer, "diffusion_config")
and trainer.diffusion_config.generate_samples
):
generation_callback = DiffusionGenerationCallback(trainer)
LOG.info("Added diffusion generation callback")
return [generation_callback]

View File

@@ -229,14 +229,17 @@ class TestDiffusionPlugin:
def test_plugin_registers_loss_function(self):
"""Test that plugin registers diffusion loss function."""
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
with patch(
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
return_value=True,
) as mock_register:
plugin = DiffusionPlugin()
mock_register.assert_called_once()
def test_post_model_load_configuration(self):
"""Test that post_model_load configures model correctly."""
plugin = DiffusionPlugin()
# Mock model and config
mock_model = Mock()
mock_model.config = Mock()
@@ -244,18 +247,20 @@ class TestDiffusionPlugin:
mock_cfg.eps = 1e-3
mock_cfg.importance_weighting = True
mock_cfg.mask_token_id = 32000
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
with patch(
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
) as mock_patch:
result = plugin.post_model_load(mock_cfg, mock_model)
# Check model configuration
assert mock_model.config.loss_type == "ForDiffusionLM"
assert mock_model.config.diffusion_config is not None
assert mock_model.config.diffusion_config['eps'] == 1e-3
assert mock_model.config.diffusion_config["eps"] == 1e-3
# Check model was patched
mock_patch.assert_called_once_with(mock_model)
# Should return the model
assert result == mock_model
@@ -264,15 +269,15 @@ class TestDiffusionPlugin:
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Set config attributes
for attr, value in diffusion_config.model_dump().items():
setattr(mock_cfg, attr, value)
plugin.post_trainer_create(mock_cfg, mock_trainer)
# Check that diffusion config was stored on trainer
assert hasattr(mock_trainer, 'diffusion_config')
assert hasattr(mock_trainer, "diffusion_config")
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
def test_add_callbacks_post_trainer_with_generation_enabled(self):
@@ -280,13 +285,15 @@ class TestDiffusionPlugin:
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation enabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
with patch(
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
) as mock_callback_class:
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return one callback
assert len(callbacks) == 1
mock_callback_class.assert_called_once_with(mock_trainer)
@@ -296,12 +303,12 @@ class TestDiffusionPlugin:
plugin = DiffusionPlugin()
mock_trainer = Mock()
mock_cfg = Mock()
# Mock trainer with diffusion config that has generation disabled
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
# Should return no callbacks
assert len(callbacks) == 0