fix for flaky tests in lora ops kernels w autotune (#3511) [skip ci]
* fix for flaky tests in lora ops kernels w autotune * attempt 2 to fix
This commit is contained in:
@@ -55,7 +55,7 @@ def _find_lora_ops_module() -> ModuleType | None:
|
||||
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
|
||||
attribute — that is the runtime copy with populated caches.
|
||||
"""
|
||||
for name, module in sys.modules.items():
|
||||
for name, module in list(sys.modules.items()):
|
||||
if (
|
||||
module is not None
|
||||
and "lora_ops" in name
|
||||
|
||||
@@ -15,6 +15,11 @@ from unittest.mock import MagicMock, patch
|
||||
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
|
||||
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
|
||||
|
||||
# Patch target for _find_lora_ops_module inside the collector module.
|
||||
_FIND_MODULE_PATH = (
|
||||
"axolotl.integrations.kernels.autotune_collector._find_lora_ops_module"
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
|
||||
"""Create a mock triton.Config-like object."""
|
||||
@@ -41,19 +46,40 @@ def _make_mock_lora_ops(
|
||||
return mod
|
||||
|
||||
|
||||
def _real_lora_ops_module_names():
|
||||
"""Return sys.modules keys that match the lora_ops discovery pattern.
|
||||
|
||||
Other tests in the same xdist worker may have loaded the *real*
|
||||
lora_ops module. We need to temporarily hide those entries so the
|
||||
discovery test finds only the mock we inject.
|
||||
"""
|
||||
return [
|
||||
name
|
||||
for name, mod in list(sys.modules.items())
|
||||
if mod is not None
|
||||
and "lora_ops" in name
|
||||
and hasattr(mod, "_scatter2scatter_lora")
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestAutotuneCollector
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestAutotuneCollector:
|
||||
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
|
||||
"""Test ``collect_autotune_configs`` with mocked kernel objects.
|
||||
|
||||
Collection tests patch ``_find_lora_ops_module`` directly so they are
|
||||
not affected by real ``lora_ops`` modules that other tests in the same
|
||||
pytest-xdist worker may have loaded into ``sys.modules``.
|
||||
"""
|
||||
|
||||
def test_empty_cache_returns_empty_list(self):
|
||||
"""When no kernel has been autotuned yet, return ``[]``."""
|
||||
mock_lora_ops = _make_mock_lora_ops()
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
@@ -68,7 +94,7 @@ class TestAutotuneCollector:
|
||||
)
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
@@ -94,7 +120,7 @@ class TestAutotuneCollector:
|
||||
dx_cache={(16, 256, 128): cfg_dx},
|
||||
)
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
@@ -113,7 +139,7 @@ class TestAutotuneCollector:
|
||||
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
@@ -133,9 +159,8 @@ class TestAutotuneCollector:
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
# Don't inject anything — the real lora_ops isn't loaded either
|
||||
# (no triton on this machine), so _find_lora_ops_module returns None.
|
||||
result = collect_autotune_configs()
|
||||
with patch(_FIND_MODULE_PATH, return_value=None):
|
||||
result = collect_autotune_configs()
|
||||
assert result == []
|
||||
|
||||
def test_finds_module_under_hash_suffixed_name(self):
|
||||
@@ -145,7 +170,13 @@ class TestAutotuneCollector:
|
||||
|
||||
# Use a different hash to prove it's not hardcoded.
|
||||
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
|
||||
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
|
||||
|
||||
# Temporarily hide any real lora_ops modules that other tests in
|
||||
# the same xdist worker may have loaded, so only our mock is found.
|
||||
real_names = _real_lora_ops_module_names()
|
||||
hide_patch = {name: None for name in real_names}
|
||||
|
||||
with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user