From f291ac029c630b6685566c610092b87bfe73b951 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 19 Mar 2026 01:18:47 -0400 Subject: [PATCH] fix for flaky tests in lora ops kernels w autotune (#3511) [skip ci] * fix for flaky tests in lora ops kernels w autotune * attempt 2 to fix --- .../kernels/autotune_collector.py | 2 +- .../test_scattermoe_autotune_telemetry.py | 49 +++++++++++++++---- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/axolotl/integrations/kernels/autotune_collector.py b/src/axolotl/integrations/kernels/autotune_collector.py index ef4111dcf..bdb5e030e 100644 --- a/src/axolotl/integrations/kernels/autotune_collector.py +++ b/src/axolotl/integrations/kernels/autotune_collector.py @@ -55,7 +55,7 @@ def _find_lora_ops_module() -> ModuleType | None: ``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel attribute — that is the runtime copy with populated caches. """ - for name, module in sys.modules.items(): + for name, module in list(sys.modules.items()): if ( module is not None and "lora_ops" in name diff --git a/tests/integrations/test_scattermoe_autotune_telemetry.py b/tests/integrations/test_scattermoe_autotune_telemetry.py index 50ac56720..7050c0f4f 100644 --- a/tests/integrations/test_scattermoe_autotune_telemetry.py +++ b/tests/integrations/test_scattermoe_autotune_telemetry.py @@ -15,6 +15,11 @@ from unittest.mock import MagicMock, patch # Simulate the hash-suffixed module name that LocalLayerRepository creates. _FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops" +# Patch target for _find_lora_ops_module inside the collector module. +_FIND_MODULE_PATH = ( + "axolotl.integrations.kernels.autotune_collector._find_lora_ops_module" +) + def _make_mock_config(kwargs, num_warps=4, num_stages=3): """Create a mock triton.Config-like object.""" @@ -41,19 +46,40 @@ def _make_mock_lora_ops( return mod +def _real_lora_ops_module_names(): + """Return sys.modules keys that match the lora_ops discovery pattern. + + Other tests in the same xdist worker may have loaded the *real* + lora_ops module. We need to temporarily hide those entries so the + discovery test finds only the mock we inject. + """ + return [ + name + for name, mod in list(sys.modules.items()) + if mod is not None + and "lora_ops" in name + and hasattr(mod, "_scatter2scatter_lora") + ] + + # ========================================================================= # TestAutotuneCollector # ========================================================================= class TestAutotuneCollector: - """Test ``collect_autotune_configs`` with mocked kernel objects.""" + """Test ``collect_autotune_configs`` with mocked kernel objects. + + Collection tests patch ``_find_lora_ops_module`` directly so they are + not affected by real ``lora_ops`` modules that other tests in the same + pytest-xdist worker may have loaded into ``sys.modules``. + """ def test_empty_cache_returns_empty_list(self): """When no kernel has been autotuned yet, return ``[]``.""" mock_lora_ops = _make_mock_lora_ops() - with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops): from axolotl.integrations.kernels.autotune_collector import ( collect_autotune_configs, ) @@ -68,7 +94,7 @@ class TestAutotuneCollector: ) mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg}) - with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops): from axolotl.integrations.kernels.autotune_collector import ( collect_autotune_configs, ) @@ -94,7 +120,7 @@ class TestAutotuneCollector: dx_cache={(16, 256, 128): cfg_dx}, ) - with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops): from axolotl.integrations.kernels.autotune_collector import ( collect_autotune_configs, ) @@ -113,7 +139,7 @@ class TestAutotuneCollector: mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg}) - with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops): from axolotl.integrations.kernels.autotune_collector import ( collect_autotune_configs, ) @@ -133,9 +159,8 @@ class TestAutotuneCollector: collect_autotune_configs, ) - # Don't inject anything — the real lora_ops isn't loaded either - # (no triton on this machine), so _find_lora_ops_module returns None. - result = collect_autotune_configs() + with patch(_FIND_MODULE_PATH, return_value=None): + result = collect_autotune_configs() assert result == [] def test_finds_module_under_hash_suffixed_name(self): @@ -145,7 +170,13 @@ class TestAutotuneCollector: # Use a different hash to prove it's not hardcoded. alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops" - with patch.dict(sys.modules, {alt_name: mock_lora_ops}): + + # Temporarily hide any real lora_ops modules that other tests in + # the same xdist worker may have loaded, so only our mock is found. + real_names = _real_lora_ops_module_names() + hide_patch = {name: None for name in real_names} + + with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}): from axolotl.integrations.kernels.autotune_collector import ( collect_autotune_configs, )