more logs

This commit is contained in:
Dan Saunders
2025-09-16 00:15:08 -04:00
parent 135b09d1de
commit 9c1829cf57

View File

@@ -1,3 +1,4 @@
import logging
import warnings
import torch
@@ -5,6 +6,8 @@ import torch
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
_LOG = logging.getLogger("axolotl.moe.patch")
def _patch_block_forward(block_cls, grouped_fn):
"""Replace block_cls.forward with grouped_fn preserving signature."""
@@ -20,13 +23,19 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend != MOEBackend.TORCH_GROUPED:
_LOG.info(
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
)
return
try:
from axolotl.kernels.moe import torch_grouped as _tg
except Exception:
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
return
if not _tg.available():
warnings.warn("torch_grouped requested but unavailable; skipping MoE patches")
_LOG.warning(
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
)
return
# Map of architecture key to (modeling module path, class name or list of class names)
@@ -75,6 +84,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
return _grouped_forward
patched = 0
for key, (mod_path, cls_names) in model_mods.items():
if not cls_names:
continue
@@ -91,6 +101,14 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
if orig_forward is None:
continue
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
patched += 1
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
except Exception as e:
# Best effort; log and skip this entry
warnings.warn(f"Skipping MoE patch for {key}: {e}")
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
if patched == 0:
_LOG.warning(
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
)
else:
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")