fix check for fp8 capability (#3324)

* fix check for fp8 capability

* handle non-cuda compute

* reduce concurrency of tests
This commit is contained in:
Wing Lian
2025-12-22 13:58:25 -05:00
committed by GitHub
parent faaff6c792
commit efeb5a4e41
4 changed files with 22 additions and 3 deletions

View File

@@ -227,6 +227,7 @@ def load_cfg(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
@@ -259,3 +260,11 @@ def load_cfg(
)
return cfg
def compute_supports_fp8() -> bool:
    """Return True when the active CUDA device's hardware supports fp8.

    fp8 kernels are gated on compute capability (9, 0) — i.e. Hopper-class
    (H100) or newer GPUs.
    NOTE(review): Ada GPUs (SM 8.9, e.g. L4/RTX 4090) also ship fp8 units;
    the (9, 0) threshold looks intentionally Hopper-only — confirm.

    Returns:
        bool: True if a CUDA device with compute capability >= (9, 0) is
        available; False otherwise, including CPU-only environments.
    """
    try:
        compute_capability = torch.cuda.get_device_capability()
        return compute_capability >= (9, 0)
    except (AssertionError, RuntimeError):
        # get_device_capability raises RuntimeError when CUDA is present but
        # no device can be initialized, and AssertionError ("Torch not
        # compiled with CUDA enabled") on CPU-only torch builds. Either way,
        # there is no fp8-capable GPU here.
        return False

View File

@@ -2,6 +2,7 @@
from typing import Annotated, Any, Literal
from accelerate.utils import is_fp8_available
from annotated_types import MinLen
from packaging import version
from pydantic import (
@@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="after")
def check_fp8(self):
    """Reject configs that request fp8 on hardware or installs lacking it.

    Raises:
        ValueError: if fp8 is enabled but the GPU has no fp8 support, or
            the GPU supports fp8 but no fp8 backend library is installed.
    """
    if self.fp8:
        if not self.capabilities.fp8:
            raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
        if not is_fp8_available():
            raise ValueError(
                "fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
            )
    return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):