"""Tests for scattermoe autotune telemetry integration.

These tests use mocking to verify the collection and reporting logic
without requiring Triton or CUDA.
"""

import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Simulate the hash-suffixed module name that LocalLayerRepository creates.
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"

# Patch target for _find_lora_ops_module inside the collector module.
_FIND_MODULE_PATH = (
    "axolotl.integrations.kernels.autotune_collector._find_lora_ops_module"
)


def _make_mock_config(kwargs, num_warps=4, num_stages=3):
    """Create a mock triton.Config-like object.

    Args:
        kwargs: Mapping stored verbatim on the ``kwargs`` attribute.
        num_warps: Value stored on the ``num_warps`` attribute.
        num_stages: Value stored on the ``num_stages`` attribute.

    Returns:
        A ``SimpleNamespace`` exposing ``kwargs``, ``num_warps`` and
        ``num_stages``.
    """
    return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)


def _make_mock_kernel(cache=None):
    """Create a mock autotuned kernel object with a ``.cache`` dict.

    Args:
        cache: Mapping to expose as ``.cache``. ``None`` (the default) yields
            a fresh empty dict; an explicitly passed empty dict is kept as-is.

    Returns:
        A ``SimpleNamespace`` whose only attribute is ``cache``.
    """
    kernel = SimpleNamespace()
    kernel.cache = cache if cache is not None else {}
    return kernel


def _make_mock_lora_ops(
    fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
):
    """Build a mock ``lora_ops`` module with the four kernel attributes.

    Each argument becomes the ``.cache`` of the corresponding mock kernel
    (``None`` means an empty cache).

    Returns:
        A ``SimpleNamespace`` with ``_scatter2scatter_lora``,
        ``_scatter2scatter_lora_dX``, ``_group_bwd_lora`` and
        ``_group_bwd_lora_fused`` attributes.
    """
    return SimpleNamespace(
        _scatter2scatter_lora=_make_mock_kernel(fwd_cache),
        _scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),
        _group_bwd_lora=_make_mock_kernel(bwd_cache),
        _group_bwd_lora_fused=_make_mock_kernel(fused_cache),
    )


def _real_lora_ops_module_names():
    """Return sys.modules keys that match the lora_ops discovery pattern.

    Other tests in the same xdist worker may have loaded the *real*
    lora_ops module.  We need to temporarily hide those entries so the
    discovery test finds only the mock we inject.

    Returns:
        Names of non-``None`` ``sys.modules`` entries whose key contains
        ``"lora_ops"`` and which expose a ``_scatter2scatter_lora`` attribute.
    """
    # Iterate over a snapshot so the scan is unaffected if sys.modules changes
    # while we walk it.
    return [
        name
        for name, mod in list(sys.modules.items())
        if mod is not None
        and "lora_ops" in name
        and hasattr(mod, "_scatter2scatter_lora")
    ]


# =========================================================================
# TestAutotuneCollector
# =========================================================================


class TestAutotuneCollector:
    """Test ``collect_autotune_configs`` with mocked kernel objects.

    Collection tests patch ``_find_lora_ops_module`` directly so they are
    not affected by real ``lora_ops`` modules that other tests in the same
    pytest-xdist worker may have loaded into ``sys.modules``.

    Every test imports the collector *before* entering its patch context so
    that the import itself never happens while ``sys.modules`` or the
    collector's globals are being patched.
    """

    def test_empty_cache_returns_empty_list(self):
        """When no kernel has been autotuned yet, return ``[]``."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        mock_lora_ops = _make_mock_lora_ops()

        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            result = collect_autotune_configs()

        assert result == []

    def test_populated_cache_returns_configs(self):
        """When a cache entry exists, it appears in the output."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        cfg = _make_mock_config(
            {"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
        )
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})

        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            result = collect_autotune_configs()

        assert len(result) == 1
        entry = result[0]
        assert entry["kernel"] == "scatter2scatter_lora_fwd"
        assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
        assert entry["config"]["BLOCK_N"] == 128
        assert entry["config"]["BLOCK_K"] == 64
        assert entry["config"]["num_warps"] == 8
        assert entry["config"]["num_stages"] == 4

    def test_multiple_kernels_and_keys(self):
        """Multiple cache entries across kernels are all returned."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
        cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
        mock_lora_ops = _make_mock_lora_ops(
            fwd_cache={(16, 256, 128): cfg_fwd},
            dx_cache={(16, 256, 128): cfg_dx},
        )

        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            result = collect_autotune_configs()

        assert len(result) == 2
        names = {r["kernel"] for r in result}
        assert "scatter2scatter_lora_fwd" in names
        assert "scatter2scatter_lora_dX" in names

    def test_extra_key_elements_stored(self):
        """Dtype or other extra elements in the cache key are captured."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
        cache_key = (512, 1024, 256, "float16", "float16")
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})

        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
            result = collect_autotune_configs()

        assert len(result) == 1
        key = result[0]["key"]
        assert key["M"] == 512
        assert key["N"] == 1024
        assert key["K"] == 256
        assert key["_extra"] == ["float16", "float16"]

    def test_no_module_in_sys_modules_returns_empty(self):
        """If no lora_ops module is loaded, return ``[]``."""
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        with patch(_FIND_MODULE_PATH, return_value=None):
            result = collect_autotune_configs()

        assert result == []

    def test_finds_module_under_hash_suffixed_name(self):
        """Collector finds lora_ops regardless of the hash suffix."""
        # Import BEFORE touching sys.modules.  ``patch.dict`` restores the
        # dictionary to its original contents on exit, so a module imported
        # for the first time inside the context would be evicted again
        # afterwards (and would be imported while the hidden names map to
        # ``None``).  Importing first also means anything the import loads is
        # included in the "real modules" snapshot taken below.
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})

        # Use a different hash to prove it's not hardcoded.
        alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"

        # Temporarily hide any real lora_ops modules that other tests in
        # the same xdist worker may have loaded, so only our mock is found.
        real_names = _real_lora_ops_module_names()
        hide_patch = {name: None for name in real_names}

        with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}):
            result = collect_autotune_configs()

        assert len(result) == 1
        assert result[0]["config"]["BLOCK_N"] == 256


# =========================================================================
# TestAutotuneReportCallback
# =========================================================================


class TestAutotuneReportCallback:
    """Test the callback fires once and sends the correct event."""

    def test_reports_once_on_first_step(self):
        """Callback should call ``send_event`` exactly once."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())

            assert mock_tm.send_event.call_count == 1
            call_kwargs = mock_tm.send_event.call_args[1]
            assert call_kwargs["event_type"] == "scattermoe-autotune"
            assert call_kwargs["properties"]["kernel_count"] == 1

            # Second call should NOT send again.
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_retries_until_step_5_then_gives_up(self):
        """If no configs found by step 5, stop retrying."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()

        with patch(
            "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
            return_value=[],
        ):
            # Drive steps 1..6 with a collector that never yields anything.
            for step in range(1, 7):
                mock_state = MagicMock()
                mock_state.global_step = step
                cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())

        assert cb._reported is True

    def test_reports_on_retry_when_data_arrives(self):
        """If step 1 has no data but step 2 does, report at step 2."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]

        call_count = 0

        def _collector():
            # First invocation reports nothing; every later one returns data.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return []
            return fake_configs

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                side_effect=_collector,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            # Step 1 — empty, no report
            s1 = MagicMock()
            s1.global_step = 1
            cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
            assert mock_tm.send_event.call_count == 0

            # Step 2 — data arrives, report
            s2 = MagicMock()
            s2.global_step = 2
            cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_includes_gpu_info(self):
        """Event properties should include GPU identification."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
        fake_gpu = {
            "gpu_name": "NVIDIA H100",
            "gpu_compute_capability": "9.0",
            "gpu_memory_bytes": 85899345920,
        }
        fake_smem = {"smem_capacity_bytes": 233472}

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_gpu_info",
                return_value=fake_gpu,
            ),
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
                return_value=fake_smem,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())

            props = mock_tm.send_event.call_args[1]["properties"]
            assert props["gpu_name"] == "NVIDIA H100"
            assert props["gpu_compute_capability"] == "9.0"
            assert props["gpu_memory_bytes"] == 85899345920
            assert props["smem_capacity_bytes"] == 233472

    def test_skips_send_when_telemetry_disabled(self):
        """If telemetry is disabled, no event is sent."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = False
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())

            assert mock_tm.send_event.call_count == 0
            # Should still mark as reported so we don't retry.
            assert cb._reported is True


# =========================================================================
# TestKernelsPluginCallbackRegistration
# =========================================================================


class TestKernelsPluginCallbackRegistration:
    """Test that ``KernelsPlugin`` registers the callback correctly."""

    def test_scattermoe_registers_callback(self):
        """When ``use_scattermoe=True``, plugin returns the callback."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        plugin = KernelsPlugin()
        cfg = MagicMock()
        cfg.use_scattermoe = True
        model = MagicMock()

        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)

        assert len(callbacks) == 1
        assert isinstance(callbacks[0], AutotuneReportCallback)

    def test_no_scattermoe_no_callback(self):
        """When ``use_scattermoe=False``, plugin returns empty list."""
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        plugin = KernelsPlugin()
        cfg = MagicMock()
        cfg.use_scattermoe = False
        model = MagicMock()

        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)

        assert callbacks == []