fix check for fp8 capability (#3324)

* fix check for fp8 capability

* handle non-cuda compute

* reduce concurrency of tests
This commit is contained in:
Wing Lian
2025-12-22 13:58:25 -05:00
committed by GitHub
parent faaff6c792
commit efeb5a4e41
4 changed files with 22 additions and 3 deletions

View File

@@ -114,7 +114,7 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
df -h df -h
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h df -h
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h df -h
@@ -196,7 +196,7 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ pytest -v --durations=10 tests/cli/

View File

@@ -29,7 +29,6 @@ flex_attention: true
flex_attn_compile_kwargs: flex_attn_compile_kwargs:
dynamic: false dynamic: false
mode: max-autotune-no-cudagraphs mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true torch_compile: true
wandb_project: wandb_project:

View File

@@ -227,6 +227,7 @@ def load_cfg(
cfg, cfg,
capabilities={ capabilities={
"bf16": is_torch_bf16_gpu_available(), "bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
@@ -259,3 +260,11 @@ def load_cfg(
) )
return cfg return cfg
def compute_supports_fp8() -> bool:
    """Report whether the current CUDA device meets the fp8 capability threshold.

    Returns:
        ``True`` when ``torch.cuda.get_device_capability()`` reports ``(9, 0)``
        or higher; ``False`` when the capability is lower or when no usable
        CUDA device can be queried.
    """
    # Non-CUDA compute (CPU-only, MPS, ...) has no compute capability to query.
    if not torch.cuda.is_available():
        return False
    try:
        capability = torch.cuda.get_device_capability()
    except (RuntimeError, AssertionError):
        # RuntimeError: driver/device initialisation failed.
        # AssertionError: raised by torch builds compiled without CUDA support.
        return False
    # NOTE(review): threshold kept at (9, 0); confirm whether (8, 9) devices
    # should also qualify for fp8.
    return capability >= (9, 0)

View File

@@ -2,6 +2,7 @@
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from accelerate.utils import is_fp8_available
from annotated_types import MinLen from annotated_types import MinLen
from packaging import version from packaging import version
from pydantic import ( from pydantic import (
@@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
return self return self
@model_validator(mode="after")
def check_fp8(self):
    """Reject an fp8 request that the hardware or installed packages cannot honour.

    Returns:
        ``self`` unchanged when fp8 is not requested or all checks pass.

    Raises:
        ValueError: if ``fp8`` is set but ``capabilities.fp8`` is falsy, or if
            the capability is present but ``is_fp8_available()`` returns a
            falsy value.
    """
    # Nothing to validate unless fp8 was actually requested.
    if not self.fp8:
        return self
    if not self.capabilities.fp8:
        raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
    # Hardware is capable; now make sure an fp8 backend is installed.
    if not is_fp8_available():
        raise ValueError(
            "fp8 requested, but missing one of ms-amp, transformer-engine or torchao."
        )
    return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_w_sdpa_bf16(cls, data): def check_sample_packing_w_sdpa_bf16(cls, data):