error msg

This commit is contained in:
Dan Saunders
2025-09-17 18:29:30 -04:00
parent b5cb345ca4
commit d2b49b2670
2 changed files with 5 additions and 2 deletions

View File

@@ -47,7 +47,6 @@ evals_per_epoch: 2
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# Enable router logits if you want aux loss/analysis
model_config: model_config:
output_router_logits: true output_router_logits: true

View File

@@ -2,11 +2,14 @@
from __future__ import annotations from __future__ import annotations
import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
_LOGGER = logging.getLogger("axolotl.moe.grouped")
def available() -> bool: def available() -> bool:
try: try:
@@ -71,7 +74,8 @@ def _call_grouped_mm(
outs.append(Y_cat[start : start + m]) outs.append(Y_cat[start : start + m])
start += m start += m
return outs return outs
except RuntimeError: except RuntimeError as err:
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
return None return None