# Source file: axolotl/tests/integrations/test_scattermoe_autotune_telemetry.py
# Captured at commit f291ac029c — "fix for flaky tests in lora ops kernels
# w autotune (#3511)" (Wing Lian, 2026-03-19). Repository-browser page chrome
# that was accidentally pasted above the module docstring has been replaced
# by this comment so the file parses as Python.
"""Tests for scattermoe autotune telemetry integration.
These tests use mocking to verify the collection and reporting logic
without requiring Triton or CUDA.
"""
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
# Patch target for _find_lora_ops_module inside the collector module.
# Patching this string (rather than a real lora_ops module) keeps the
# collection tests independent of whatever other xdist workers imported.
_FIND_MODULE_PATH = (
    "axolotl.integrations.kernels.autotune_collector._find_lora_ops_module"
)
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
"""Create a mock triton.Config-like object."""
return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)
def _make_mock_kernel(cache=None):
"""Create a mock autotuned kernel object with a ``.cache`` dict."""
kernel = SimpleNamespace()
kernel.cache = cache if cache is not None else {}
return kernel
def _make_mock_lora_ops(
    fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
):
    """Assemble a fake ``lora_ops`` module with the four kernel attributes.

    Each attribute is a mock kernel whose ``.cache`` is the corresponding
    argument (or a fresh empty dict when ``None``).
    """
    caches = {
        "_scatter2scatter_lora": fwd_cache,
        "_scatter2scatter_lora_dX": dx_cache,
        "_group_bwd_lora": bwd_cache,
        "_group_bwd_lora_fused": fused_cache,
    }
    return SimpleNamespace(
        **{attr: _make_mock_kernel(cache) for attr, cache in caches.items()}
    )
def _real_lora_ops_module_names():
"""Return sys.modules keys that match the lora_ops discovery pattern.
Other tests in the same xdist worker may have loaded the *real*
lora_ops module. We need to temporarily hide those entries so the
discovery test finds only the mock we inject.
"""
return [
name
for name, mod in list(sys.modules.items())
if mod is not None
and "lora_ops" in name
and hasattr(mod, "_scatter2scatter_lora")
]
# =========================================================================
# TestAutotuneCollector
# =========================================================================
class TestAutotuneCollector:
    """Test ``collect_autotune_configs`` with mocked kernel objects.

    Collection tests patch ``_find_lora_ops_module`` directly so they are
    not affected by real ``lora_ops`` modules that other tests in the same
    pytest-xdist worker may have loaded into ``sys.modules``.
    """

    def test_empty_cache_returns_empty_list(self):
        """When no kernel has been autotuned yet, return ``[]``."""
        mock_lora_ops = _make_mock_lora_ops()
        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            # Import inside the patch context so collection runs while the
            # discovery helper is replaced by the mock.
            from axolotl.integrations.kernels.autotune_collector import (
                collect_autotune_configs,
            )

            result = collect_autotune_configs()
            assert result == []

    def test_populated_cache_returns_configs(self):
        """When a cache entry exists, it appears in the output."""
        cfg = _make_mock_config(
            {"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
        )
        # Cache is keyed by the (M, N, K) problem-size tuple.
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            from axolotl.integrations.kernels.autotune_collector import (
                collect_autotune_configs,
            )

            result = collect_autotune_configs()
            assert len(result) == 1
            entry = result[0]
            # The tuple key is mapped to a named {"M", "N", "K"} dict and the
            # config kwargs are flattened together with warps/stages.
            assert entry["kernel"] == "scatter2scatter_lora_fwd"
            assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
            assert entry["config"]["BLOCK_N"] == 128
            assert entry["config"]["BLOCK_K"] == 64
            assert entry["config"]["num_warps"] == 8
            assert entry["config"]["num_stages"] == 4

    def test_multiple_kernels_and_keys(self):
        """Multiple cache entries across kernels are all returned."""
        cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
        cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
        mock_lora_ops = _make_mock_lora_ops(
            fwd_cache={(16, 256, 128): cfg_fwd},
            dx_cache={(16, 256, 128): cfg_dx},
        )
        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            from axolotl.integrations.kernels.autotune_collector import (
                collect_autotune_configs,
            )

            result = collect_autotune_configs()
            assert len(result) == 2
            names = {r["kernel"] for r in result}
            assert "scatter2scatter_lora_fwd" in names
            assert "scatter2scatter_lora_dX" in names

    def test_extra_key_elements_stored(self):
        """Dtype or other extra elements in the cache key are captured."""
        cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
        # Elements past the first three (M, N, K) land in "_extra".
        cache_key = (512, 1024, 256, "float16", "float16")
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            from axolotl.integrations.kernels.autotune_collector import (
                collect_autotune_configs,
            )

            result = collect_autotune_configs()
            assert len(result) == 1
            key = result[0]["key"]
            assert key["M"] == 512
            assert key["N"] == 1024
            assert key["K"] == 256
            assert key["_extra"] == ["float16", "float16"]

    def test_no_module_in_sys_modules_returns_empty(self):
        """If no lora_ops module is loaded, return ``[]``."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        with patch(_FIND_MODULE_PATH, return_value=None):
            result = collect_autotune_configs()
            assert result == []

    def test_finds_module_under_hash_suffixed_name(self):
        """Collector finds lora_ops regardless of the hash suffix."""
        cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})
        # Use a different hash to prove it's not hardcoded.
        alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
        # Temporarily hide any real lora_ops modules that other tests in
        # the same xdist worker may have loaded, so only our mock is found.
        real_names = _real_lora_ops_module_names()
        hide_patch = {name: None for name in real_names}
        # patch.dict restores sys.modules on exit, so the hidden entries
        # reappear for subsequent tests.
        with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}):
            from axolotl.integrations.kernels.autotune_collector import (
                collect_autotune_configs,
            )

            result = collect_autotune_configs()
            assert len(result) == 1
            assert result[0]["config"]["BLOCK_N"] == 256
# =========================================================================
# TestAutotuneReportCallback
# =========================================================================
class TestAutotuneReportCallback:
    """Test the callback fires once and sends the correct event."""

    def test_reports_once_on_first_step(self):
        """Callback should call ``send_event`` exactly once."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1
        fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]
        with (
            # Patch the collector at its definition site (string target) so
            # the callback's lookup resolves to the fake.
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            # The callback fetches the singleton via get_instance().
            mock_tm_cls.get_instance.return_value = mock_tm
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 1
            call_kwargs = mock_tm.send_event.call_args[1]
            assert call_kwargs["event_type"] == "scattermoe-autotune"
            assert call_kwargs["properties"]["kernel_count"] == 1
            # Second call should NOT send again.
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_retries_until_step_5_then_gives_up(self):
        """If no configs found by step 5, stop retrying."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        with patch(
            "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
            return_value=[],
        ):
            # Drive steps 1..6; collector never yields data, so by the end
            # the callback must have given up and marked itself reported.
            for step in range(1, 7):
                mock_state = MagicMock()
                mock_state.global_step = step
                cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert cb._reported is True

    def test_reports_on_retry_when_data_arrives(self):
        """If step 1 has no data but step 2 does, report at step 2."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
        call_count = 0

        # Stateful side_effect: empty on the first call, data afterwards.
        def _collector():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return []
            return fake_configs

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                side_effect=_collector,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm
            # Step 1 — empty, no report
            s1 = MagicMock()
            s1.global_step = 1
            cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
            assert mock_tm.send_event.call_count == 0
            # Step 2 — data arrives, report
            s2 = MagicMock()
            s2.global_step = 2
            cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_includes_gpu_info(self):
        """Event properties should include GPU identification."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1
        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
        fake_gpu = {
            "gpu_name": "NVIDIA H100",
            "gpu_compute_capability": "9.0",
            "gpu_memory_bytes": 85899345920,
        }
        fake_smem = {"smem_capacity_bytes": 233472}
        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            # GPU/shared-memory probes are mocked so the test needs no CUDA.
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_gpu_info",
                return_value=fake_gpu,
            ),
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
                return_value=fake_smem,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            props = mock_tm.send_event.call_args[1]["properties"]
            assert props["gpu_name"] == "NVIDIA H100"
            assert props["gpu_compute_capability"] == "9.0"
            assert props["gpu_memory_bytes"] == 85899345920
            assert props["smem_capacity_bytes"] == 233472

    def test_skips_send_when_telemetry_disabled(self):
        """If telemetry is disabled, no event is sent."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1
        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = False
            mock_tm_cls.get_instance.return_value = mock_tm
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 0
            # Should still mark as reported so we don't retry.
            assert cb._reported is True
# =========================================================================
# TestKernelsPluginCallbackRegistration
# =========================================================================
class TestKernelsPluginCallbackRegistration:
    """Verify ``KernelsPlugin`` wires up the autotune report callback."""

    def test_scattermoe_registers_callback(self):
        """When ``use_scattermoe=True``, plugin returns the callback."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        cfg = MagicMock()
        cfg.use_scattermoe = True
        model = MagicMock()
        registered = KernelsPlugin().add_callbacks_pre_trainer(cfg, model)
        # Exactly one callback, and it is the autotune reporter.
        assert len(registered) == 1
        assert isinstance(registered[0], AutotuneReportCallback)

    def test_no_scattermoe_no_callback(self):
        """When ``use_scattermoe=False``, plugin returns empty list."""
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        cfg = MagicMock()
        cfg.use_scattermoe = False
        model = MagicMock()
        registered = KernelsPlugin().add_callbacks_pre_trainer(cfg, model)
        assert registered == []