# Provenance: helper added to src/axolotl/utils/trainer.py by commit
# 23ac14540b286e3c9dfecc8a8d47e96891542bf8, "Basic evaluate CLI command /
# codepath (#2188)". The same commit imports it in src/axolotl/train.py
# alongside `setup_trainer`.
def set_pytorch_cuda_alloc_conf():
    """Give ``PYTORCH_CUDA_ALLOC_CONF`` a default value on PyTorch >= 2.2.

    The installed version is read from ``torch.__version__``; only the
    leading ``major.minor`` components are inspected, so local suffixes such
    as ``2.2.0+cu121`` are accepted.

    When the version is at least 2.2 and the environment variable is not
    already present, it is set to
    ``expandable_segments:True,roundup_power2_divisions:16``. A value that
    is already set (even an empty string) is left untouched.

    Returns:
        None. The only effect is the possible mutation of ``os.environ``.

    Raises:
        ValueError: if the major or minor component is not a plain integer.
        IndexError: if the version string has no minor component.
    """
    # NOTE(review): presumably this has to run before the first CUDA
    # allocation for the setting to take effect -- confirm against callers.
    version_parts = torch.__version__.split(".")
    installed = (int(version_parts[0]), int(version_parts[1]))

    # Compare (major, minor) as a tuple so that releases with a major
    # version above 2 also count as ">= 2.2"; the previous
    # `major == 2 and minor >= 2` test silently skipped them.
    if installed < (2, 2):
        return

    # setdefault only writes when the key is absent, so an explicit
    # user-provided configuration always wins.
    os.environ.setdefault(
        "PYTORCH_CUDA_ALLOC_CONF",
        "expandable_segments:True,roundup_power2_divisions:16",
    )
def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ):