fix check for fp8 capability (#3324)
* fix check for fp8 capability * handle non-cuda compute * reduce concurrency of tests
This commit is contained in:
@@ -227,6 +227,7 @@ def load_cfg(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
@@ -259,3 +260,11 @@ def load_cfg(
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def compute_supports_fp8() -> bool:
    """Return True if the current CUDA device's compute capability supports fp8.

    fp8 training requires Hopper-class (or newer) hardware, i.e. a CUDA
    compute capability of at least (9, 0). Returns False instead of raising
    on machines without a usable CUDA device.
    """
    # A CPU-only torch build raises AssertionError (not RuntimeError) from
    # get_device_capability(), so the original try/except alone does not
    # cover the non-CUDA case; check availability explicitly first.
    if not torch.cuda.is_available():
        return False
    try:
        # (major, minor) tuple, e.g. (9, 0) for H100.
        return torch.cuda.get_device_capability() >= (9, 0)
    except RuntimeError:
        # CUDA build present but the device query failed; treat as unsupported.
        return False
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from accelerate.utils import is_fp8_available
|
||||
from annotated_types import MinLen
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
@@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fp8(self):
|
||||
if self.fp8 and not self.capabilities.fp8:
|
||||
raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
|
||||
elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():
|
||||
raise ValueError(
|
||||
"fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user