automatically enable tf32 if supported (#3473) [skip ci]
* automatically enable tf32 if supported * update fixtures * handle only when True * Address CR comments * address readability from pr comment * simplify
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -143,7 +145,8 @@ class TestLlama:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_batch_flattening(self, temp_dir):
|
||||
@pytest.mark.parametrize("tf32", ["auto", False])
|
||||
def test_batch_flattening(self, tf32, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -171,6 +174,7 @@ class TestLlama:
|
||||
"sample_packing": False,
|
||||
"batch_flattening": True,
|
||||
"bf16": True,
|
||||
"tf32": tf32,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -68,6 +68,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"tf32": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
|
||||
@@ -8,7 +8,13 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
@pytest.fixture()
|
||||
def gpu_caps():
|
||||
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
|
||||
return {
|
||||
"compute_capability": "sm_89",
|
||||
"bf16": True,
|
||||
"tf32": False,
|
||||
"n_gpu": 1,
|
||||
"n_node": 1,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
Reference in New Issue
Block a user