From 6cdcb8ddd5fca39d6f9b32285868eb4f5a869406 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 26 Mar 2025 18:14:43 -0400
Subject: [PATCH] Set the pytorch_cuda_alloc_conf env in the train module
 (#2447)

---
 src/axolotl/cli/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 6cc7c7701..e225141b6 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault
 
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
         cfg: Dictionary mapping `axolotl` config keys to values.
         cli_args: Training-specific CLI arguments.
     """
+    # Enable expandable segments for cuda allocation to improve VRAM usage
+    set_pytorch_cuda_alloc_conf()
+
     print_axolotl_text_art()
     check_accelerate_default_config()
     if int(os.getenv("LOCAL_RANK", "0")) == 0: