diffusion alt: custom loss impl
This commit is contained in:
@@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin):
|
||||
"""Returns the pydantic model for LLaDA plugin arguments."""
|
||||
return "axolotl.integrations.diffusion.DiffusionArgs"
|
||||
|
||||
def post_model_load(
|
||||
self, cfg: DictDefault, model: PreTrainedModel | PeftModel
|
||||
):
|
||||
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Configure model for diffusion training after loading."""
|
||||
self.cfg = cfg
|
||||
|
||||
@@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin):
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
|
||||
"""Add diffusion generation callback if enabled."""
|
||||
if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples:
|
||||
if (
|
||||
hasattr(trainer, "diffusion_config")
|
||||
and trainer.diffusion_config.generate_samples
|
||||
):
|
||||
generation_callback = DiffusionGenerationCallback(trainer)
|
||||
LOG.info("Added diffusion generation callback")
|
||||
return [generation_callback]
|
||||
|
||||
@@ -229,14 +229,17 @@ class TestDiffusionPlugin:
|
||||
|
||||
def test_plugin_registers_loss_function(self):
|
||||
"""Test that plugin registers diffusion loss function."""
|
||||
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
|
||||
with patch(
|
||||
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
|
||||
return_value=True,
|
||||
) as mock_register:
|
||||
plugin = DiffusionPlugin()
|
||||
mock_register.assert_called_once()
|
||||
|
||||
def test_post_model_load_configuration(self):
|
||||
"""Test that post_model_load configures model correctly."""
|
||||
plugin = DiffusionPlugin()
|
||||
|
||||
|
||||
# Mock model and config
|
||||
mock_model = Mock()
|
||||
mock_model.config = Mock()
|
||||
@@ -244,18 +247,20 @@ class TestDiffusionPlugin:
|
||||
mock_cfg.eps = 1e-3
|
||||
mock_cfg.importance_weighting = True
|
||||
mock_cfg.mask_token_id = 32000
|
||||
|
||||
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
|
||||
|
||||
with patch(
|
||||
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
|
||||
) as mock_patch:
|
||||
result = plugin.post_model_load(mock_cfg, mock_model)
|
||||
|
||||
|
||||
# Check model configuration
|
||||
assert mock_model.config.loss_type == "ForDiffusionLM"
|
||||
assert mock_model.config.diffusion_config is not None
|
||||
assert mock_model.config.diffusion_config['eps'] == 1e-3
|
||||
|
||||
assert mock_model.config.diffusion_config["eps"] == 1e-3
|
||||
|
||||
# Check model was patched
|
||||
mock_patch.assert_called_once_with(mock_model)
|
||||
|
||||
|
||||
# Should return the model
|
||||
assert result == mock_model
|
||||
|
||||
@@ -264,15 +269,15 @@ class TestDiffusionPlugin:
|
||||
plugin = DiffusionPlugin()
|
||||
mock_trainer = Mock()
|
||||
mock_cfg = Mock()
|
||||
|
||||
|
||||
# Set config attributes
|
||||
for attr, value in diffusion_config.model_dump().items():
|
||||
setattr(mock_cfg, attr, value)
|
||||
|
||||
|
||||
plugin.post_trainer_create(mock_cfg, mock_trainer)
|
||||
|
||||
|
||||
# Check that diffusion config was stored on trainer
|
||||
assert hasattr(mock_trainer, 'diffusion_config')
|
||||
assert hasattr(mock_trainer, "diffusion_config")
|
||||
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
|
||||
|
||||
def test_add_callbacks_post_trainer_with_generation_enabled(self):
|
||||
@@ -280,13 +285,15 @@ class TestDiffusionPlugin:
|
||||
plugin = DiffusionPlugin()
|
||||
mock_trainer = Mock()
|
||||
mock_cfg = Mock()
|
||||
|
||||
|
||||
# Mock trainer with diffusion config that has generation enabled
|
||||
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
|
||||
|
||||
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
|
||||
|
||||
with patch(
|
||||
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
|
||||
) as mock_callback_class:
|
||||
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
||||
|
||||
|
||||
# Should return one callback
|
||||
assert len(callbacks) == 1
|
||||
mock_callback_class.assert_called_once_with(mock_trainer)
|
||||
@@ -296,12 +303,12 @@ class TestDiffusionPlugin:
|
||||
plugin = DiffusionPlugin()
|
||||
mock_trainer = Mock()
|
||||
mock_cfg = Mock()
|
||||
|
||||
|
||||
# Mock trainer with diffusion config that has generation disabled
|
||||
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
|
||||
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
||||
|
||||
|
||||
# Should return no callbacks
|
||||
assert len(callbacks) == 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user