more logs
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -5,6 +6,8 @@ import torch
|
|||||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||||
|
|
||||||
|
_LOG = logging.getLogger("axolotl.moe.patch")
|
||||||
|
|
||||||
|
|
||||||
def _patch_block_forward(block_cls, grouped_fn):
|
def _patch_block_forward(block_cls, grouped_fn):
|
||||||
"""Replace block_cls.forward with grouped_fn preserving signature."""
|
"""Replace block_cls.forward with grouped_fn preserving signature."""
|
||||||
@@ -20,13 +23,19 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||||
backend = get_moe_backend_name(preferred)
|
backend = get_moe_backend_name(preferred)
|
||||||
if backend != MOEBackend.TORCH_GROUPED:
|
if backend != MOEBackend.TORCH_GROUPED:
|
||||||
|
_LOG.info(
|
||||||
|
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
from axolotl.kernels.moe import torch_grouped as _tg
|
from axolotl.kernels.moe import torch_grouped as _tg
|
||||||
except Exception:
|
except Exception:
|
||||||
|
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
|
||||||
return
|
return
|
||||||
if not _tg.available():
|
if not _tg.available():
|
||||||
warnings.warn("torch_grouped requested but unavailable; skipping MoE patches")
|
_LOG.warning(
|
||||||
|
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Map of architecture key to (modeling module path, class name or list of class names)
|
# Map of architecture key to (modeling module path, class name or list of class names)
|
||||||
@@ -75,6 +84,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
|
|
||||||
return _grouped_forward
|
return _grouped_forward
|
||||||
|
|
||||||
|
patched = 0
|
||||||
for key, (mod_path, cls_names) in model_mods.items():
|
for key, (mod_path, cls_names) in model_mods.items():
|
||||||
if not cls_names:
|
if not cls_names:
|
||||||
continue
|
continue
|
||||||
@@ -91,6 +101,14 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
if orig_forward is None:
|
if orig_forward is None:
|
||||||
continue
|
continue
|
||||||
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
|
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
|
||||||
|
patched += 1
|
||||||
|
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Best effort; log and skip this entry
|
# Best effort; log and skip this entry
|
||||||
warnings.warn(f"Skipping MoE patch for {key}: {e}")
|
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
|
||||||
|
if patched == 0:
|
||||||
|
_LOG.warning(
|
||||||
|
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")
|
||||||
|
|||||||
Reference in New Issue
Block a user