diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5de1bc114..8e718af9b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -53,11 +53,11 @@ def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: # enable expandable segments for cuda allocation to improve VRAM usage - # torch_version = torch.__version__.split(".") - # torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - # if torch_major == 2 and torch_minor >= 2: - # if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # load the tokenizer first LOG.debug(