error msg
This commit is contained in:
@@ -47,7 +47,6 @@ evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# Enable router logits if you want aux loss/analysis
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
||||
|
||||
|
||||
def available() -> bool:
|
||||
try:
|
||||
@@ -71,7 +74,8 @@ def _call_grouped_mm(
|
||||
outs.append(Y_cat[start : start + m])
|
||||
start += m
|
||||
return outs
|
||||
except RuntimeError:
|
||||
except RuntimeError as err:
|
||||
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user