diff --git a/scripts/finetune.py b/scripts/finetune.py index 7c4d865fa..898f88c2c 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -183,6 +183,9 @@ def train( cfg.fp16 = True cfg.bf16 = False + if cfg.tf32: + torch.backends.cuda.matmul.allow_tf32 = True + # load the tokenizer first tokenizer_config = cfg.tokenizer_config or cfg.base_model_config logging.info(f"loading tokenizer... {tokenizer_config}")