diff --git a/examples/moe/qwen2-moe-qlora-10gb.yaml b/examples/moe/qwen2-moe-qlora-10gb.yaml index a364f8647..fc330c50e 100644 --- a/examples/moe/qwen2-moe-qlora-10gb.yaml +++ b/examples/moe/qwen2-moe-qlora-10gb.yaml @@ -47,7 +47,6 @@ evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 -# Enable router logits if you want aux loss/analysis model_config: output_router_logits: true diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index aea1063b7..90d6e9828 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -2,11 +2,14 @@ from __future__ import annotations +import logging from typing import List, Optional, Tuple import torch import torch.nn.functional as F +_LOGGER = logging.getLogger("axolotl.moe.grouped") + def available() -> bool: try: @@ -71,7 +74,8 @@ def _call_grouped_mm( outs.append(Y_cat[start : start + m]) start += m return outs - except RuntimeError: + except RuntimeError as err: + _LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err) return None