diffusion alt: custom loss impl
This commit is contained in:
@@ -38,9 +38,7 @@ class DiffusionPlugin(BasePlugin):
|
|||||||
"""Returns the pydantic model for LLaDA plugin arguments."""
|
"""Returns the pydantic model for LLaDA plugin arguments."""
|
||||||
return "axolotl.integrations.diffusion.DiffusionArgs"
|
return "axolotl.integrations.diffusion.DiffusionArgs"
|
||||||
|
|
||||||
def post_model_load(
|
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||||
self, cfg: DictDefault, model: PreTrainedModel | PeftModel
|
|
||||||
):
|
|
||||||
"""Configure model for diffusion training after loading."""
|
"""Configure model for diffusion training after loading."""
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
@@ -88,7 +86,10 @@ class DiffusionPlugin(BasePlugin):
|
|||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
|
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer):
|
||||||
"""Add diffusion generation callback if enabled."""
|
"""Add diffusion generation callback if enabled."""
|
||||||
if hasattr(trainer, 'diffusion_config') and trainer.diffusion_config.generate_samples:
|
if (
|
||||||
|
hasattr(trainer, "diffusion_config")
|
||||||
|
and trainer.diffusion_config.generate_samples
|
||||||
|
):
|
||||||
generation_callback = DiffusionGenerationCallback(trainer)
|
generation_callback = DiffusionGenerationCallback(trainer)
|
||||||
LOG.info("Added diffusion generation callback")
|
LOG.info("Added diffusion generation callback")
|
||||||
return [generation_callback]
|
return [generation_callback]
|
||||||
|
|||||||
@@ -229,14 +229,17 @@ class TestDiffusionPlugin:
|
|||||||
|
|
||||||
def test_plugin_registers_loss_function(self):
|
def test_plugin_registers_loss_function(self):
|
||||||
"""Test that plugin registers diffusion loss function."""
|
"""Test that plugin registers diffusion loss function."""
|
||||||
with patch('axolotl.integrations.diffusion.plugin.register_diffusion_loss', return_value=True) as mock_register:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.register_diffusion_loss",
|
||||||
|
return_value=True,
|
||||||
|
) as mock_register:
|
||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
mock_register.assert_called_once()
|
mock_register.assert_called_once()
|
||||||
|
|
||||||
def test_post_model_load_configuration(self):
|
def test_post_model_load_configuration(self):
|
||||||
"""Test that post_model_load configures model correctly."""
|
"""Test that post_model_load configures model correctly."""
|
||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
|
|
||||||
# Mock model and config
|
# Mock model and config
|
||||||
mock_model = Mock()
|
mock_model = Mock()
|
||||||
mock_model.config = Mock()
|
mock_model.config = Mock()
|
||||||
@@ -244,18 +247,20 @@ class TestDiffusionPlugin:
|
|||||||
mock_cfg.eps = 1e-3
|
mock_cfg.eps = 1e-3
|
||||||
mock_cfg.importance_weighting = True
|
mock_cfg.importance_weighting = True
|
||||||
mock_cfg.mask_token_id = 32000
|
mock_cfg.mask_token_id = 32000
|
||||||
|
|
||||||
with patch('axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention') as mock_patch:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.patch_model_for_bidirectional_attention"
|
||||||
|
) as mock_patch:
|
||||||
result = plugin.post_model_load(mock_cfg, mock_model)
|
result = plugin.post_model_load(mock_cfg, mock_model)
|
||||||
|
|
||||||
# Check model configuration
|
# Check model configuration
|
||||||
assert mock_model.config.loss_type == "ForDiffusionLM"
|
assert mock_model.config.loss_type == "ForDiffusionLM"
|
||||||
assert mock_model.config.diffusion_config is not None
|
assert mock_model.config.diffusion_config is not None
|
||||||
assert mock_model.config.diffusion_config['eps'] == 1e-3
|
assert mock_model.config.diffusion_config["eps"] == 1e-3
|
||||||
|
|
||||||
# Check model was patched
|
# Check model was patched
|
||||||
mock_patch.assert_called_once_with(mock_model)
|
mock_patch.assert_called_once_with(mock_model)
|
||||||
|
|
||||||
# Should return the model
|
# Should return the model
|
||||||
assert result == mock_model
|
assert result == mock_model
|
||||||
|
|
||||||
@@ -264,15 +269,15 @@ class TestDiffusionPlugin:
|
|||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
mock_trainer = Mock()
|
mock_trainer = Mock()
|
||||||
mock_cfg = Mock()
|
mock_cfg = Mock()
|
||||||
|
|
||||||
# Set config attributes
|
# Set config attributes
|
||||||
for attr, value in diffusion_config.model_dump().items():
|
for attr, value in diffusion_config.model_dump().items():
|
||||||
setattr(mock_cfg, attr, value)
|
setattr(mock_cfg, attr, value)
|
||||||
|
|
||||||
plugin.post_trainer_create(mock_cfg, mock_trainer)
|
plugin.post_trainer_create(mock_cfg, mock_trainer)
|
||||||
|
|
||||||
# Check that diffusion config was stored on trainer
|
# Check that diffusion config was stored on trainer
|
||||||
assert hasattr(mock_trainer, 'diffusion_config')
|
assert hasattr(mock_trainer, "diffusion_config")
|
||||||
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
|
assert mock_trainer.diffusion_config.eps == diffusion_config.eps
|
||||||
|
|
||||||
def test_add_callbacks_post_trainer_with_generation_enabled(self):
|
def test_add_callbacks_post_trainer_with_generation_enabled(self):
|
||||||
@@ -280,13 +285,15 @@ class TestDiffusionPlugin:
|
|||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
mock_trainer = Mock()
|
mock_trainer = Mock()
|
||||||
mock_cfg = Mock()
|
mock_cfg = Mock()
|
||||||
|
|
||||||
# Mock trainer with diffusion config that has generation enabled
|
# Mock trainer with diffusion config that has generation enabled
|
||||||
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
|
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=True)
|
||||||
|
|
||||||
with patch('axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback') as mock_callback_class:
|
with patch(
|
||||||
|
"axolotl.integrations.diffusion.plugin.DiffusionGenerationCallback"
|
||||||
|
) as mock_callback_class:
|
||||||
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
||||||
|
|
||||||
# Should return one callback
|
# Should return one callback
|
||||||
assert len(callbacks) == 1
|
assert len(callbacks) == 1
|
||||||
mock_callback_class.assert_called_once_with(mock_trainer)
|
mock_callback_class.assert_called_once_with(mock_trainer)
|
||||||
@@ -296,12 +303,12 @@ class TestDiffusionPlugin:
|
|||||||
plugin = DiffusionPlugin()
|
plugin = DiffusionPlugin()
|
||||||
mock_trainer = Mock()
|
mock_trainer = Mock()
|
||||||
mock_cfg = Mock()
|
mock_cfg = Mock()
|
||||||
|
|
||||||
# Mock trainer with diffusion config that has generation disabled
|
# Mock trainer with diffusion config that has generation disabled
|
||||||
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
|
mock_trainer.diffusion_config = DiffusionArgs(generate_samples=False)
|
||||||
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
callbacks = plugin.add_callbacks_post_trainer(mock_cfg, mock_trainer)
|
||||||
|
|
||||||
# Should return no callbacks
|
# Should return no callbacks
|
||||||
assert len(callbacks) == 0
|
assert len(callbacks) == 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user