let hf trainer handle torch compile (#516)

* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
Wing Lian
2023-09-13 11:42:12 -04:00
committed by GitHub
parent 36e53c7442
commit a4e1bb6606
3 changed files with 20 additions and 4 deletions

View File

@@ -80,10 +80,6 @@ def train(
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout