diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index f81ba0b2e..568c115cc 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -11,7 +11,7 @@ from urllib.parse import urlparse import requests import torch import yaml -from transformers.utils import is_torch_bf16_gpu_available +from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available from axolotl.integrations.base import PluginManager from axolotl.telemetry.errors import send_errors @@ -310,6 +310,7 @@ def load_cfg( capabilities={ "bf16": is_torch_bf16_gpu_available(), "fp8": compute_supports_fp8(), + "tf32": is_torch_tf32_available(), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "compute_capability": gpu_version, }, diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index c23433866..a149566b3 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC): def _configure_precision_settings(self, training_args_kwargs: dict): training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False - training_args_kwargs["tf32"] = self.cfg.tf32 + training_args_kwargs["tf32"] = self.cfg.tf32 is True if self.cfg.bf16 == "full": training_args_kwargs["bf16_full_eval"] = True else: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index b779abaa6..61096cb86 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -84,7 +84,7 @@ def resolve_dtype(cfg): cfg.fp16 = True cfg.bf16 = False else: - if cfg.tf32: + if cfg.tf32 is True: torch.set_float32_matmul_precision("high") if is_torch_greater_or_equal("2.9.0"): torch.backends.fp32_precision = "tf32" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 5ea340c37..a4eadf5cf 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py 
@@ -407,9 +407,11 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "No AMP (automatic mixed precision)"}, ) # for non-AMP cases - tf32: bool | None = Field( - default=None, - json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"}, + tf32: Literal["auto"] | bool | None = Field( + default="auto", + json_schema_extra={ + "description": "bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere" + }, ) float32: bool | None = None @@ -1218,6 +1220,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return self + @model_validator(mode="after") + def check_tf32(self): + if self.tf32 == "auto": + self.tf32 = self.capabilities.tf32 + return self + @model_validator(mode="after") def check_fp8(self): if self.fp8 and not self.capabilities.fp8: diff --git a/src/axolotl/utils/schemas/internal/__init__.py b/src/axolotl/utils/schemas/internal/__init__.py index 692dee833..78cc636db 100644 --- a/src/axolotl/utils/schemas/internal/__init__.py +++ b/src/axolotl/utils/schemas/internal/__init__.py @@ -10,6 +10,7 @@ class GPUCapabilities(BaseModel): bf16: bool = Field(default=False) fp8: bool = Field(default=False) + tf32: bool = Field(default=False) n_gpu: int = Field(default=1) n_node: int = Field(default=1) compute_capability: Optional[str] = Field(default=None) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 795b0de37..47a791577 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -2,6 +2,8 @@ E2E tests for llama """ +import pytest + from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -143,7 +145,8 @@ class TestLlama: train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - def test_batch_flattening(self, temp_dir): + @pytest.mark.parametrize("tf32", ["auto", False]) + def test_batch_flattening(self, tf32, temp_dir): cfg = DictDefault( { "base_model": 
"HuggingFaceTB/SmolLM2-135M", @@ -171,6 +174,7 @@ class TestLlama: "sample_packing": False, "batch_flattening": True, "bf16": True, + "tf32": tf32, "save_first_step": False, } ) diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index 464812a90..27740db18 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -68,6 +68,7 @@ class TestValidationCheckDatasetConfig(BaseValidation): cfg, capabilities={ "bf16": "false", + "tf32": "false", "n_gpu": 1, "compute_capability": "8.0", }, diff --git a/tests/utils/schemas/validation/test_moe_quant.py b/tests/utils/schemas/validation/test_moe_quant.py index 2c34582c3..52b6f52c5 100644 --- a/tests/utils/schemas/validation/test_moe_quant.py +++ b/tests/utils/schemas/validation/test_moe_quant.py @@ -8,7 +8,13 @@ from axolotl.utils.dict import DictDefault @pytest.fixture() def gpu_caps(): - return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1} + return { + "compute_capability": "sm_89", + "bf16": True, + "tf32": False, + "n_gpu": 1, + "n_node": 1, + } @pytest.fixture()