fix compile

This commit is contained in:
Dan Saunders
2025-09-19 13:55:54 -04:00
parent b5dc58373f
commit ce21da9177
2 changed files with 11 additions and 5 deletions

View File

@@ -13,6 +13,7 @@ from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
@@ -163,6 +164,8 @@ def main() -> None:
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: