text diffusion training plugin (#3067)
* diffusion training plugin * cleanup * nits * fixes + improvements * add back in reinit_weights (clobbered?); masking / pretrain fixes * nits * cleanup; tests draft * sample generation, tests fixes * fixes * nits * add inference support; add auto-mask token support * nits * nits * progress * simplify logging * lint * prefix args with diffusion_ * coderabbito * tests fix * nit * nits * cleanup + nits * nits * fix SFT sample gen * fixes * fix * comments * comments * lint * reward model lora fix * cleanup; fix pretraining_dataset case * gradio inference * update cfgs * update cfgs * train, generation parity, cleanup * fix * simplify * test * test fix
This commit is contained in:
92
tests/integrations/test_diffusion_callback.py
Normal file
92
tests/integrations/test_diffusion_callback.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for diffusion generation callback dataloader selection and triggering."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.integrations.diffusion import DiffusionGenerationCallback
|
||||
|
||||
|
||||
class DummyTrainer:
|
||||
"""Minimal trainer double with required attributes/methods for the callback."""
|
||||
|
||||
def __init__(self, use_eval: bool):
|
||||
# Config used by callback
|
||||
self.cfg = SimpleNamespace(
|
||||
diffusion=SimpleNamespace(
|
||||
generation_interval=1,
|
||||
num_generation_samples=1,
|
||||
generation_max_length=32,
|
||||
generation_steps=4,
|
||||
generation_temperature=0.0,
|
||||
mask_token_id=16,
|
||||
),
|
||||
use_wandb=False,
|
||||
)
|
||||
|
||||
# Model/tokenizer are passed through to generate_samples; not used here
|
||||
self.model = Mock()
|
||||
self.processing_class = Mock()
|
||||
|
||||
# Datasets and loaders
|
||||
self.eval_dataset = object() if use_eval else None
|
||||
self._train_loader = object()
|
||||
self._eval_loader = object()
|
||||
|
||||
# State for world process check
|
||||
self.state = SimpleNamespace(is_world_process_zero=True)
|
||||
|
||||
# Track which loader was requested
|
||||
self.requested: list[str] = []
|
||||
|
||||
def get_train_dataloader(self):
|
||||
self.requested.append("train")
|
||||
return self._train_loader
|
||||
|
||||
def get_eval_dataloader(self):
|
||||
self.requested.append("eval")
|
||||
return self._eval_loader
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eval", [False, True])
|
||||
def test_callback_uses_correct_dataloader(monkeypatch, use_eval):
|
||||
trainer = DummyTrainer(use_eval=use_eval)
|
||||
callback = DiffusionGenerationCallback(trainer)
|
||||
|
||||
captured = {}
|
||||
|
||||
# Patch generate_samples in the callback module's namespace
|
||||
def fake_generate_samples(**kwargs):
|
||||
captured["dataloader"] = kwargs.get("dataloader")
|
||||
# Return one dummy sample to exercise logging path
|
||||
return [
|
||||
{
|
||||
"original": "o",
|
||||
"masked": "m",
|
||||
"generated": "g",
|
||||
"mask_ratio": 0.5,
|
||||
"masked_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
}
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.integrations.diffusion.callbacks.generate_samples",
|
||||
fake_generate_samples,
|
||||
)
|
||||
|
||||
# Trigger at step 1 (interval=1)
|
||||
args = SimpleNamespace()
|
||||
state = SimpleNamespace(global_step=1)
|
||||
control = SimpleNamespace()
|
||||
|
||||
callback.on_step_end(args=args, state=state, control=control)
|
||||
|
||||
# Assert the expected dataloader path was used
|
||||
if use_eval:
|
||||
assert trainer.requested[0] == "eval"
|
||||
assert captured["dataloader"] is trainer._eval_loader
|
||||
else:
|
||||
assert trainer.requested[0] == "train"
|
||||
assert captured["dataloader"] is trainer._train_loader
|
||||
Reference in New Issue
Block a user