automatically enable tf32 if supported (#3473) [skip ci]
* automatically enable tf32 if supported
* update fixtures
* handle only when True
* Address CR comments
* address readability from pr comment
* simplify
This commit is contained in:
@@ -11,7 +11,7 @@ from urllib.parse import urlparse
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.telemetry.errors import send_errors
|
from axolotl.telemetry.errors import send_errors
|
||||||
@@ -310,6 +310,7 @@ def load_cfg(
|
|||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
"fp8": compute_supports_fp8(),
|
"fp8": compute_supports_fp8(),
|
||||||
|
"tf32": is_torch_tf32_available(),
|
||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
def _configure_precision_settings(self, training_args_kwargs: dict):
|
def _configure_precision_settings(self, training_args_kwargs: dict):
|
||||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_args_kwargs["bf16_full_eval"] = True
|
training_args_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ def resolve_dtype(cfg):
|
|||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
else:
|
else:
|
||||||
if cfg.tf32:
|
if cfg.tf32 is True:
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
if is_torch_greater_or_equal("2.9.0"):
|
if is_torch_greater_or_equal("2.9.0"):
|
||||||
torch.backends.fp32_precision = "tf32"
|
torch.backends.fp32_precision = "tf32"
|
||||||
|
|||||||
@@ -407,9 +407,11 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "No AMP (automatic mixed precision)"},
|
json_schema_extra={"description": "No AMP (automatic mixed precision)"},
|
||||||
) # for non-AMP cases
|
) # for non-AMP cases
|
||||||
tf32: bool | None = Field(
|
tf32: Literal["auto"] | bool | None = Field(
|
||||||
default=None,
|
default="auto",
|
||||||
json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"},
|
json_schema_extra={
|
||||||
|
"description": "bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
float32: bool | None = None
|
float32: bool | None = None
|
||||||
|
|
||||||
@@ -1218,6 +1220,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_tf32(self):
|
||||||
|
if self.tf32 == "auto":
|
||||||
|
self.tf32 = self.capabilities.tf32
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fp8(self):
|
def check_fp8(self):
|
||||||
if self.fp8 and not self.capabilities.fp8:
|
if self.fp8 and not self.capabilities.fp8:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class GPUCapabilities(BaseModel):
|
|||||||
|
|
||||||
bf16: bool = Field(default=False)
|
bf16: bool = Field(default=False)
|
||||||
fp8: bool = Field(default=False)
|
fp8: bool = Field(default=False)
|
||||||
|
tf32: bool = Field(default=False)
|
||||||
n_gpu: int = Field(default=1)
|
n_gpu: int = Field(default=1)
|
||||||
n_node: int = Field(default=1)
|
n_node: int = Field(default=1)
|
||||||
compute_capability: Optional[str] = Field(default=None)
|
compute_capability: Optional[str] = Field(default=None)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
E2E tests for llama
|
E2E tests for llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
@@ -143,7 +145,8 @@ class TestLlama:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_batch_flattening(self, temp_dir):
|
@pytest.mark.parametrize("tf32", ["auto", False])
|
||||||
|
def test_batch_flattening(self, tf32, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
@@ -171,6 +174,7 @@ class TestLlama:
|
|||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"batch_flattening": True,
|
"batch_flattening": True,
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
|
"tf32": tf32,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": "false",
|
"bf16": "false",
|
||||||
|
"tf32": "false",
|
||||||
"n_gpu": 1,
|
"n_gpu": 1,
|
||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,7 +8,13 @@ from axolotl.utils.dict import DictDefault
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def gpu_caps():
|
def gpu_caps():
|
||||||
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
|
return {
|
||||||
|
"compute_capability": "sm_89",
|
||||||
|
"bf16": True,
|
||||||
|
"tf32": False,
|
||||||
|
"n_gpu": 1,
|
||||||
|
"n_node": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
|||||||
Reference in New Issue
Block a user