error msg
This commit is contained in:
@@ -47,7 +47,6 @@ evals_per_epoch: 2
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# Enable router logits if you want aux loss/analysis
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
||||||
|
|
||||||
|
|
||||||
def available() -> bool:
|
def available() -> bool:
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +74,8 @@ def _call_grouped_mm(
|
|||||||
outs.append(Y_cat[start : start + m])
|
outs.append(Y_cat[start : start + m])
|
||||||
start += m
|
start += m
|
||||||
return outs
|
return outs
|
||||||
except RuntimeError:
|
except RuntimeError as err:
|
||||||
|
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user