let hf trainer handle torch compile (#516)

* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch._dynamo errors (torch._dynamo.config.suppress_errors = True) to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
Wing Lian
2023-09-13 11:42:12 -04:00
committed by GitHub
parent 36e53c7442
commit a4e1bb6606
3 changed files with 20 additions and 4 deletions

View File

@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# where to save the finished model to
output_dir: ./completed-model
# whether to use torch.compile and which backend to use (torch>=2.1.0 is required for torch_compile to work properly)
torch_compile: # bool
torch_compile_backend: # Optional[str]
# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2

View File

@@ -80,10 +80,6 @@ def train(
model.config.use_cache = False model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
import torch
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.greater_is_better: if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config # DDP Config
if cfg.ddp_timeout: if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout