Set matmul tf32

This commit is contained in:
NanoCode012
2023-06-08 23:41:12 +09:00
parent 73e9ea4069
commit 52765ac588

View File

@@ -183,6 +183,9 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")