use new tf32 APIs for torch 2.9+ (#3467) [skip ci]

* use new tf32 APIs for torch 2.9+

* also upgrade cce for tf32 fixes and lint
This commit is contained in:
Wing Lian
2026-03-06 11:40:32 -05:00
committed by GitHub
parent c119382337
commit fc2d63ee5f
6 changed files with 18 additions and 7 deletions

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
)

View File

@@ -6,7 +6,10 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
-from transformers.utils.import_utils import is_torch_npu_available
+from transformers.utils.import_utils import (
+    is_torch_greater_or_equal,
+    is_torch_npu_available,
+)
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
@@ -81,8 +84,15 @@ def resolve_dtype(cfg):
cfg.fp16 = True
cfg.bf16 = False
else:
-        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
-        torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
+        if cfg.tf32:
+            torch.set_float32_matmul_precision("high")
+            if is_torch_greater_or_equal("2.9.0"):
+                torch.backends.fp32_precision = "tf32"
+                torch.backends.cuda.matmul.fp32_precision = "tf32"
+                torch.backends.cudnn.fp32_precision = "tf32"
+            else:
+                torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cudnn.allow_tf32 = True
if cfg.bf16:
cfg.fp16 = False