From 9c1829cf575221045167825f02e34152c955131f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 16 Sep 2025 00:15:08 -0400 Subject: [PATCH] more logs --- src/axolotl/monkeypatch/moe_grouped.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py index 24408d03a..4bf4f6578 100644 --- a/src/axolotl/monkeypatch/moe_grouped.py +++ b/src/axolotl/monkeypatch/moe_grouped.py @@ -1,3 +1,4 @@ +import logging import warnings import torch @@ -5,6 +6,8 @@ import torch from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name +_LOG = logging.getLogger("axolotl.moe.patch") + def _patch_block_forward(block_cls, grouped_fn): """Replace block_cls.forward with grouped_fn preserving signature.""" @@ -20,13 +23,19 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None: preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None backend = get_moe_backend_name(preferred) if backend != MOEBackend.TORCH_GROUPED: + _LOG.info( + f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches" + ) return try: from axolotl.kernels.moe import torch_grouped as _tg except Exception: + _LOG.warning("torch_grouped backend import failed; skipping grouped patches") return if not _tg.available(): - warnings.warn("torch_grouped requested but unavailable; skipping MoE patches") + _LOG.warning( + "torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches" + ) return # Map of architecture key to (modeling module path, class name or list of class names) @@ -75,6 +84,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None: return _grouped_forward + patched = 0 for key, (mod_path, cls_names) in model_mods.items(): if not cls_names: continue @@ -91,6 +101,14 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None: if orig_forward is None: continue _patch_block_forward(block_cls, make_grouped_forward(orig_forward)) + patched += 1 + _LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}") except Exception as e: # Best effort; log and skip this entry - warnings.warn(f"Skipping MoE patch for {key}: {e}") + _LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}") + if patched == 0: + _LOG.warning( + "No MoE blocks patched for grouped GEMM; model may not use known MoE classes" + ) + else: + _LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")