use new tf32 APIs for torch 2.9+ (#3467) [skip ci]

* use new tf32 APIs for torch 2.9+
* also upgrade cce for tf32 fixes and lint

This commit is contained in:
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
 
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
 ```
 
 ## Usage
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
 )
@@ -6,7 +6,10 @@ from typing import Optional
 
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
-from transformers.utils.import_utils import is_torch_npu_available
+from transformers.utils.import_utils import (
+    is_torch_greater_or_equal,
+    is_torch_npu_available,
+)
 
 from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
@@ -81,8 +84,15 @@ def resolve_dtype(cfg):
         cfg.fp16 = True
         cfg.bf16 = False
     else:
-        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
-        torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
+        if cfg.tf32:
+            torch.set_float32_matmul_precision("high")
+            if is_torch_greater_or_equal("2.9.0"):
+                torch.backends.fp32_precision = "tf32"
+                torch.backends.cuda.matmul.fp32_precision = "tf32"
+                torch.backends.cudnn.fp32_precision = "tf32"
+            else:
+                torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cudnn.allow_tf32 = True
     if cfg.bf16:
         cfg.fp16 = False
 
Reference in New Issue
Block a user