let hf trainer handle torch compile (#516)
* let hf trainer handle torch compile * remove torch compile checks, include option for backend * suppress torch errors to get further * require min torch version of 2.1.0 for torch compile to work --------- Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|||||||
# where to save the finished model to
|
# where to save the finished model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
|
# whether to use torch.compile and which backend to use
|
||||||
|
torch_compile: # bool
|
||||||
|
torch_compile_backend: # Optional[str]
|
||||||
|
|
||||||
# training hyperparameters
|
# training hyperparameters
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
|
|||||||
@@ -80,10 +80,6 @@ def train(
|
|||||||
|
|
||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
|
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
|
||||||
LOG.info("Compiling torch model")
|
|
||||||
model = torch.compile(model)
|
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset, set_caching_enabled
|
from datasets import Dataset, set_caching_enabled
|
||||||
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.greater_is_better:
|
if cfg.greater_is_better:
|
||||||
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||||
|
|
||||||
|
if cfg.torch_compile:
|
||||||
|
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||||
|
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||||
|
else:
|
||||||
|
import torch._dynamo # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
|
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||||
|
True
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
|
||||||
|
if cfg.torch_compile_backend:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"torch_compile_backend"
|
||||||
|
] = cfg.torch_compile_backend
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if cfg.ddp_timeout:
|
if cfg.ddp_timeout:
|
||||||
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
||||||
|
|||||||
Reference in New Issue
Block a user