Merge pull request #163 from NanoCode012/feat/matmul-tf32

Feat: Set matmul tf32=True when tf32 passed
This commit is contained in:
NanoCode012
2023-06-09 00:01:31 +09:00
committed by GitHub

View File

@@ -183,6 +183,9 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")