fix for flaky tests in lora ops kernels w autotune (#3511) [skip ci]

* fix for flaky tests in lora ops kernels w autotune

* attempt 2 to fix
This commit is contained in:
Wing Lian
2026-03-19 01:18:47 -04:00
committed by GitHub
parent 5ef3f28340
commit f291ac029c
2 changed files with 41 additions and 10 deletions

View File

@@ -55,7 +55,7 @@ def _find_lora_ops_module() -> ModuleType | None:
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
attribute — that is the runtime copy with populated caches.
"""
for name, module in sys.modules.items():
for name, module in list(sys.modules.items()):
if (
module is not None
and "lora_ops" in name

View File

@@ -15,6 +15,11 @@ from unittest.mock import MagicMock, patch
# Fake ``sys.modules`` key that simulates the hash-suffixed module name
# that LocalLayerRepository creates.
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
# Dotted path handed to ``patch`` to replace ``_find_lora_ops_module`` inside
# the collector module, so a test decides what module discovery returns.
_FIND_MODULE_PATH = (
"axolotl.integrations.kernels.autotune_collector._find_lora_ops_module"
)
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
"""Create a mock triton.Config-like object."""
@@ -41,19 +46,40 @@ def _make_mock_lora_ops(
return mod
def _real_lora_ops_module_names():
    """List the ``sys.modules`` keys that match the lora_ops discovery pattern.

    A key qualifies when its entry is not ``None``, the key contains
    ``"lora_ops"``, and the module object exposes a
    ``_scatter2scatter_lora`` attribute.

    Other tests running in the same xdist worker may have loaded the
    *real* lora_ops module.  The discovery test uses this list to hide
    those entries temporarily so that only the injected mock is found.
    """
    matches = []
    # Walk a snapshot so that changes to ``sys.modules`` made during the
    # scan cannot invalidate the iteration.
    for key, candidate in tuple(sys.modules.items()):
        if candidate is None:
            continue
        if "lora_ops" not in key:
            continue
        if hasattr(candidate, "_scatter2scatter_lora"):
            matches.append(key)
    return matches
# =========================================================================
# TestAutotuneCollector
# =========================================================================
class TestAutotuneCollector:
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
"""Test ``collect_autotune_configs`` with mocked kernel objects.
Collection tests patch ``_find_lora_ops_module`` directly so they are
not affected by real ``lora_ops`` modules that other tests in the same
pytest-xdist worker may have loaded into ``sys.modules``.
"""
def test_empty_cache_returns_empty_list(self):
"""When no kernel has been autotuned yet, return ``[]``."""
mock_lora_ops = _make_mock_lora_ops()
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
@@ -68,7 +94,7 @@ class TestAutotuneCollector:
)
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
@@ -94,7 +120,7 @@ class TestAutotuneCollector:
dx_cache={(16, 256, 128): cfg_dx},
)
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
@@ -113,7 +139,7 @@ class TestAutotuneCollector:
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
@@ -133,9 +159,8 @@ class TestAutotuneCollector:
collect_autotune_configs,
)
# Don't inject anything — the real lora_ops isn't loaded either
# (no triton on this machine), so _find_lora_ops_module returns None.
result = collect_autotune_configs()
with patch(_FIND_MODULE_PATH, return_value=None):
result = collect_autotune_configs()
assert result == []
def test_finds_module_under_hash_suffixed_name(self):
@@ -145,7 +170,13 @@ class TestAutotuneCollector:
# Use a different hash to prove it's not hardcoded.
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
# Temporarily hide any real lora_ops modules that other tests in
# the same xdist worker may have loaded, so only our mock is found.
real_names = _real_lora_ops_module_names()
hide_patch = {name: None for name in real_names}
with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)