use new tf32 APIs for torch 2.9+ (#3467) [skip ci]

* use new tf32 APIs for torch 2.9+

* also upgrade cce for tf32 fixes and lint
This commit is contained in:
Wing Lian
2026-03-06 11:40:32 -05:00
committed by GitHub
parent c119382337
commit fc2d63ee5f
6 changed files with 18 additions and 7 deletions

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
] ]
}, },
{ {

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
) )

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
``` ```
## Usage ## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
) )

View File

@@ -6,7 +6,10 @@ from typing import Optional
import torch import torch
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import (
is_torch_greater_or_equal,
is_torch_npu_available,
)
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args from axolotl.integrations.config import merge_input_args
@@ -81,8 +84,15 @@ def resolve_dtype(cfg):
cfg.fp16 = True cfg.fp16 = True
cfg.bf16 = False cfg.bf16 = False
else: else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False if cfg.tf32:
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False torch.set_float32_matmul_precision("high")
if is_torch_greater_or_equal("2.9.0"):
torch.backends.fp32_precision = "tf32"
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "tf32"
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if cfg.bf16: if cfg.bf16:
cfg.fp16 = False cfg.fp16 = False

View File

@@ -4,6 +4,7 @@ from unittest.mock import patch
import addict import addict
import pytest import pytest
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault