diff --git a/README.md b/README.md
index 176894bdd..0ba61eb58 100644
--- a/README.md
+++ b/README.md
@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
 # where to save the finished model to
 output_dir: ./completed-model
 
+# whether to use torch.compile and which backend to use
+torch_compile: # bool
+torch_compile_backend: # Optional[str]
+
 # training hyperparameters
 gradient_accumulation_steps: 1
 micro_batch_size: 2
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b1be8f8a3..fa6dbceaf 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -80,10 +80,6 @@ def train(
 
     model.config.use_cache = False
 
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        LOG.info("Compiling torch model")
-        model = torch.compile(model)
-
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
         LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9685176b6..946505c22 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Optional, Union
 
 import numpy as np
+import torch
 import torch.cuda
 import transformers
 from datasets import Dataset, set_caching_enabled
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.greater_is_better:
         training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
 
+    if cfg.torch_compile:
+        if torch.__version__ < "2.1.0":  # pylint: disable=protected-access
+            LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
+        else:
+            import torch._dynamo  # pylint: disable=redefined-outer-name
+
+            torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
+                True
+            )
+            training_arguments_kwargs["torch_compile"] = cfg.torch_compile
+            if cfg.torch_compile_backend:
+                training_arguments_kwargs[
+                    "torch_compile_backend"
+                ] = cfg.torch_compile_backend
+
     # DDP Config
     if cfg.ddp_timeout:
         training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
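With this change, `torch.compile` is no longer applied unconditionally in `train()`; it becomes opt-in through the Hugging Face `TrainingArguments` (`torch_compile` / `torch_compile_backend`) and is only enabled on torch >= 2.1.0. A minimal config sketch of how the new keys might be set (the values here are illustrative, not taken from the diff):

```yaml
# opt in to torch.compile (requires torch >= 2.1.0)
torch_compile: true
# optional backend forwarded to TrainingArguments.torch_compile_backend,
# e.g. "inductor" (the torch.compile default)
torch_compile_backend: inductor
```

Leaving both keys unset keeps compilation disabled, since the previous unconditional `torch.compile(model)` call in `train()` is removed.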