automatically enable tf32 if supported (#3473) [skip ci]

* automatically enable tf32 if supported

* update fixtures

* handle only when True

* Address CR comments

* address readability from pr comment

* simplify
This commit is contained in:
Wing Lian
2026-03-16 23:47:00 -04:00
committed by GitHub
parent d230cbbde3
commit 830e9f7eaf
8 changed files with 29 additions and 8 deletions

View File

@@ -2,6 +2,8 @@
E2E tests for llama
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -143,7 +145,8 @@ class TestLlama:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_batch_flattening(self, temp_dir):
@pytest.mark.parametrize("tf32", ["auto", False])
def test_batch_flattening(self, tf32, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -171,6 +174,7 @@ class TestLlama:
"sample_packing": False,
"batch_flattening": True,
"bf16": True,
"tf32": tf32,
"save_first_step": False,
}
)

View File

@@ -68,6 +68,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
cfg,
capabilities={
"bf16": "false",
"tf32": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},

View File

@@ -8,7 +8,13 @@ from axolotl.utils.dict import DictDefault
@pytest.fixture()
def gpu_caps():
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
return {
"compute_capability": "sm_89",
"bf16": True,
"tf32": False,
"n_gpu": 1,
"n_node": 1,
}
@pytest.fixture()