"""Tests for diffusion generation callback dataloader selection and triggering."""
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
from axolotl.integrations.diffusion import DiffusionGenerationCallback
|
|
|
|
|
|
class DummyTrainer:
|
|
"""Minimal trainer double with required attributes/methods for the callback."""
|
|
|
|
def __init__(self, use_eval: bool):
|
|
# Config used by callback
|
|
self.cfg = SimpleNamespace(
|
|
diffusion=SimpleNamespace(
|
|
generation_interval=1,
|
|
num_generation_samples=1,
|
|
generation_max_length=32,
|
|
generation_steps=4,
|
|
generation_temperature=0.0,
|
|
mask_token_id=16,
|
|
),
|
|
use_wandb=False,
|
|
)
|
|
|
|
# Model/tokenizer are passed through to generate_samples; not used here
|
|
self.model = Mock()
|
|
self.processing_class = Mock()
|
|
|
|
# Datasets and loaders
|
|
self.eval_dataset = object() if use_eval else None
|
|
self._train_loader = object()
|
|
self._eval_loader = object()
|
|
|
|
# State for world process check
|
|
self.state = SimpleNamespace(is_world_process_zero=True)
|
|
|
|
# Track which loader was requested
|
|
self.requested: list[str] = []
|
|
|
|
def get_train_dataloader(self):
|
|
self.requested.append("train")
|
|
return self._train_loader
|
|
|
|
def get_eval_dataloader(self):
|
|
self.requested.append("eval")
|
|
return self._eval_loader
|
|
|
|
|
|
@pytest.mark.parametrize("use_eval", [False, True])
|
|
def test_callback_uses_correct_dataloader(monkeypatch, use_eval):
|
|
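    """Callback should request the eval dataloader when an eval dataset exists, else the train dataloader."""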
    trainer = DummyTrainer(use_eval=use_eval)
    callback = DiffusionGenerationCallback(trainer)

    captured = {}

    # Patch generate_samples in the callback module's namespace
    def fake_generate_samples(**kwargs):
        captured["dataloader"] = kwargs.get("dataloader")
        # Return one dummy sample to exercise logging path
        return [
            {
                "original": "o",
                "masked": "m",
                "generated": "g",
                "mask_ratio": 0.5,
                "masked_tokens": 1,
                "total_tokens": 2,
            }
        ]

    monkeypatch.setattr(
        "axolotl.integrations.diffusion.callbacks.generate_samples",
        fake_generate_samples,
    )

    # Trigger at step 1 (interval=1)
    args = SimpleNamespace()
    state = SimpleNamespace(global_step=1)
    control = SimpleNamespace()

    callback.on_step_end(args=args, state=state, control=control)

    # Assert the expected dataloader path was used
    if use_eval:
        assert trainer.requested[0] == "eval"
        assert captured["dataloader"] is trainer._eval_loader
    else:
        assert trainer.requested[0] == "train"
        assert captured["dataloader"] is trainer._train_loader
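

# To run only this module, a typical pytest invocation looks like the following
# (the path is hypothetical; adjust it to wherever this file lives in the repo):
#   pytest -q tests/test_diffusion_generation_callback.py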