diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 634575066..b527dce08 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -78,6 +78,7 @@ def resolve_dtype(cfg): cfg.bf16 = False else: torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False + torch.backends.cudnn.allow_tf32 = cfg.tf32 or False if cfg.bf16: cfg.fp16 = False