let hf trainer handle torch compile (#516)
* let hf trainer handle torch compile
* remove torch compile checks, include option for backend
* suppress torch errors to get further
* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
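The change removes axolotl's hand-rolled torch.compile call and instead forwards cfg.torch_compile / cfg.torch_compile_backend to the Hugging Face Trainer. A minimal sketch, using only the public transformers API (output_dir and the commented Trainer call are placeholders, not values from the repo), of the path the trainer now handles:

import torch
from transformers import Trainer, TrainingArguments

# TrainingArguments exposes torch_compile / torch_compile_backend;
# the Trainer then calls torch.compile on the model internally, so
# the training script no longer needs an explicit compile step.
args = TrainingArguments(
    output_dir="./out",                # placeholder path
    torch_compile=True,                # maps from cfg.torch_compile
    torch_compile_backend="inductor",  # maps from cfg.torch_compile_backend
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)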
@@ -80,10 +80,6 @@ def train(
 
     model.config.use_cache = False
 
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        LOG.info("Compiling torch model")
-        model = torch.compile(model)
-
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
         LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
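For context, the deleted block compiled the model itself whenever torch 2.x was detected (except on Windows). A quick illustrative sketch, not code from the repo, of why that pattern is redundant once the trainer compiles: torch.compile returns a wrapper module, so compiling in train() and again in the trainer would wrap an already-wrapped model.

import torch

model = torch.nn.Linear(4, 4)
wrapped = torch.compile(model)  # what the removed lines did by hand

# On torch 2.x, torch.compile returns an OptimizedModule wrapper
# around the original module rather than modifying it in place.
print(type(wrapped))               # torch._dynamo.eval_frame.OptimizedModule
print(wrapped._orig_mod is model)  # True: the original module is kept inside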
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Optional, Union
 
 import numpy as np
 import torch
+import torch.cuda
 import transformers
 from datasets import Dataset, set_caching_enabled
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.greater_is_better:
         training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
 
+    if cfg.torch_compile:
+        if torch.__version__ < "2.1.0":  # pylint: disable=protected-access
+            LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
+        else:
+            import torch._dynamo  # pylint: disable=redefined-outer-name
+
+            torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
+                True
+            )
+            training_arguments_kwargs["torch_compile"] = cfg.torch_compile
+            if cfg.torch_compile_backend:
+                training_arguments_kwargs[
+                    "torch_compile_backend"
+                ] = cfg.torch_compile_backend
+
     # DDP Config
     if cfg.ddp_timeout:
         training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
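The suppress_errors flag set in the new block makes TorchDynamo log compile failures and fall back to eager execution instead of raising ("suppress torch errors to get further" in the commit message). A minimal sketch of the flag's effect, independent of axolotl:

import torch
import torch._dynamo

# With suppress_errors on, a region that Dynamo/Inductor cannot
# compile is logged and the function simply runs eagerly instead of
# raising, so a long training run is not killed by one bad graph.
torch._dynamo.config.suppress_errors = True

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(2, 8))  # compiles on first call, or falls back

The trade-off is that genuine errors in compiled regions are demoted to logged fallbacks, which is why the diff sets the flag only inside the cfg.torch_compile branch rather than globally.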