diffusion alt: custom loss impl
This commit is contained in:
@@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin):
|
|||||||
"""Returns the pydantic model for LLaDA plugin arguments."""
|
"""Returns the pydantic model for LLaDA plugin arguments."""
|
||||||
return "axolotl.integrations.diffusion.DiffusionArgs"
|
return "axolotl.integrations.diffusion.DiffusionArgs"
|
||||||
|
|
||||||
def post_model_load(
|
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||||
self, cfg: DictDefault, model: PreTrainedModel | PeftModel
|
|
||||||
):
|
|
||||||
"""Configure model for diffusion training after loading."""
|
"""Configure model for diffusion training after loading."""
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
@@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin):
|
|||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
|
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
|
||||||
"""Add diffusion generation callback if enabled."""
|
"""Add diffusion generation callback if enabled."""
|
||||||
if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples:
|
if (
|
||||||
|
hasattr(trainer, "diffusion_config")
|
||||||
|
and trainer.diffusion_config.generate_samples
|
||||||
|
):
|
||||||
generation_callback = DiffusionGenerationCallback(trainer)
|
generation_callback = DiffusionGenerationCallback(trainer)
|
||||||
LOG.info("Added diffusion generation callback")
|
LOG.info("Added diffusion generation callback")
|
||||||
return [generation_callback]
|
return [generation_callback]
|
||||||
|
|||||||
@@ -229,7 +229,10 @@ class TestDiffusionPlugin:
|
|||||||
|
|
||||||
def test_plugin_registers_loss_function(self):
|
def test_plugin_registers_loss_function(self):
|
||||||
"""Test that plugin registers diffusion loss function."""
|
"""Test that plugin registers diffusion loss function."""
|
||||||
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
|
||||||
|
return_value=True,
|
||||||
|
) as mock_register:
|
||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
mock_register.assert_called_once()
|
mock_register.assert_called_once()
|
||||||
|
|
||||||
@@ -245,13 +248,15 @@ class TestDiffusionPlugin:
|
|||||||
mock_cfg.importance_weighting = True
|
mock_cfg.importance_weighting = True
|
||||||
mock_cfg.mask_token_id = 32000
|
mock_cfg.mask_token_id = 32000
|
||||||
|
|
||||||
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
|
||||||
|
) as mock_patch:
|
||||||
result = plugin.post_model_load(mock_cfg, mock_model)
|
result = plugin.post_model_load(mock_cfg, mock_model)
|
||||||
|
|
||||||
# Check model configuration
|
# Check model configuration
|
||||||
assert mock_model.config.loss_type == "ForDiffusionLM"
|
assert mock_model.config.loss_type == "ForDiffusionLM"
|
||||||
assert mock_model.config.diffusion_config is not None
|
assert mock_model.config.diffusion_config is not None
|
||||||
assert mock_model.config.diffusion_config['eps'] == 1e-3
|
assert mock_model.config.diffusion_config["eps"] == 1e-3
|
||||||
|
|
||||||
# Check model was patched
|
# Check model was patched
|
||||||
mock_patch.assert_called_once_with(mock_model)
|
mock_patch.assert_called_once_with(mock_model)
|
||||||
@@ -272,7 +277,7 @@ class TestDiffusionPlugin:
|
|||||||
plugin.post_trainer_create(mock_cfg, mock_trainer)
|
plugin.post_trainer_create(mock_cfg, mock_trainer)
|
||||||
|
|
||||||
# Check that diffusion config was stored on trainer
|
# Check that diffusion config was stored on trainer
|
||||||
assert hasattr(mock_trainer, 'diffusion_config')
|
assert hasattr(mock_trainer, "diffusion_config")
|
||||||
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
|
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
|
||||||
|
|
||||||
def test_add_callbacks_post_trainer_with_generation_enabled(self):
|
def test_add_callbacks_post_trainer_with_generation_enabled(self):
|
||||||
@@ -284,7 +289,9 @@ class TestDiffusionPlugin:
|
|||||||
# Mock trainer with diffusion config that has generation enabled
|
# Mock trainer with diffusion config that has generation enabled
|
||||||
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
|
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
|
||||||
|
|
||||||
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
|
||||||
|
) as mock_callback_class:
|
||||||
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
||||||
|
|
||||||
# Should return one callback
|
# Should return one callback
|
||||||
|
|||||||
Reference in New Issue
Block a user