Add an additional TF32 option for cuDNN (#2477) [skip ci]

This commit is contained in:
Wing Lian
2025-04-03 08:47:52 -04:00
committed by GitHub
parent 3877c5c69d
commit 5249e98058

View File

@@ -78,6 +78,7 @@ def resolve_dtype(cfg):
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
cfg.fp16 = False