From fc2d63ee5f7bdb6c56bd0f70888aec75db068db3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 6 Mar 2026 11:40:32 -0500 Subject: [PATCH] use new tf32 APIs for torch 2.9+ (#3467) [skip ci] * use new tf32 APIs for torch 2.9+ * also upgrade cce for tf32 fixes and lint --- .../colab-notebooks/colab-axolotl-example.ipynb | 2 +- scripts/cutcrossentropy_install.py | 2 +- .../integrations/cut_cross_entropy/README.md | 2 +- .../integrations/cut_cross_entropy/__init__.py | 2 +- src/axolotl/utils/config/__init__.py | 16 +++++++++++++--- tests/test_tensor_parallel_batch_size.py | 1 + 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 2cc27f211..7be9800be 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\"" ] }, { diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index f6cd0c495..d506fa87e 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index b892033da..9520dd48c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129" ``` ## Usage diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 5c207e0fc..d8aa075b9 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`' ) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 07c4d175f..e8ca72aa1 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -6,7 +6,10 @@ from typing import Optional import torch from transformers.utils import is_torch_bf16_gpu_available -from transformers.utils.import_utils import is_torch_npu_available +from transformers.utils.import_utils import ( + is_torch_greater_or_equal, + is_torch_npu_available, +) from axolotl.integrations.base import PluginManager from axolotl.integrations.config import merge_input_args @@ -81,8 +84,15 @@ def resolve_dtype(cfg): cfg.fp16 = True cfg.bf16 = False else: - torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False - torch.backends.cudnn.allow_tf32 = cfg.tf32 or False + if cfg.tf32: + torch.set_float32_matmul_precision("high") + if is_torch_greater_or_equal("2.9.0"): + torch.backends.fp32_precision = "tf32" + torch.backends.cuda.matmul.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "tf32" + else: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True if cfg.bf16: cfg.fp16 = False diff --git a/tests/test_tensor_parallel_batch_size.py b/tests/test_tensor_parallel_batch_size.py index f0b27a8eb..c6a8174fc 100644 --- a/tests/test_tensor_parallel_batch_size.py +++ b/tests/test_tensor_parallel_batch_size.py @@ -4,6 +4,7 @@ from unittest.mock import patch import addict import pytest + from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault