Fix(misc): address PYTORCH_CUDA_ALLOC_CONF deprecation (#3313)

* fix: leftover ministral docs changes

* fix: pytorch_cuda_alloc_conf deprecation

* fix: set old PYTORCH_CUDA_ALLOC_CONF env too

* handle 2.9 separately

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2025-12-17 21:12:18 +07:00
committed by GitHub
parent 2a664dc8ad
commit a1d07f42e4
4 changed files with 16 additions and 9 deletions

View File

@@ -41,14 +41,22 @@ def get_pytorch_version() -> tuple[int, int, int]:
def set_pytorch_cuda_alloc_conf():
    """Set a default CUDA allocator config unless the user already chose one.

    The default is written to ``os.environ`` under the variable name that
    matches the installed PyTorch version:

    * PyTorch >= 2.9: ``PYTORCH_ALLOC_CONF`` (``PYTORCH_CUDA_ALLOC_CONF`` is
      deprecated there).
    * PyTorch 2.2 - 2.8: ``PYTORCH_CUDA_ALLOC_CONF``.
    * Older versions: nothing is set.

    An existing value in either variable is treated as an explicit user
    choice and is never overwritten or shadowed.
    """
    default_conf = "expandable_segments:True,roundup_power2_divisions:16"
    new_var = "PYTORCH_ALLOC_CONF"
    legacy_var = "PYTORCH_CUDA_ALLOC_CONF"

    # Only major/minor matter; this also tolerates suffixes such as
    # "2.9.0+cu128" because the patch component is never parsed.
    version_parts = torch.__version__.split(".")
    version = (int(version_parts[0]), int(version_parts[1]))

    if version >= (2, 9):
        # Check BOTH names: a user who still exports the legacy variable has
        # made an explicit choice, and writing our default to the new name
        # could shadow it.
        # NOTE(review): PyTorch is understood to prefer PYTORCH_ALLOC_CONF
        # when both are present - confirm against the allocator docs.
        if os.getenv(new_var) is None and os.getenv(legacy_var) is None:
            os.environ[new_var] = default_conf
    elif version >= (2, 2):
        if os.getenv(legacy_var) is None:
            os.environ[legacy_var] = default_conf
def set_misc_env():