From efeb5a4e41007a1e87e7ee780590938c94665899 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 22 Dec 2025 13:58:25 -0500
Subject: [PATCH] fix check for fp8 capability (#3324)

* fix check for fp8 capability

* handle non-cuda compute

* reduce concurrency of tests
---
 .github/workflows/tests.yml         |  4 ++--
 examples/llama-3/3b-fp8-fsdp2.yaml  |  1 -
 src/axolotl/cli/config.py           |  9 +++++++++
 src/axolotl/utils/schemas/config.py | 11 +++++++++++
 4 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1cbfc15e1..0dc61b7ff 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -114,7 +114,7 @@ jobs:
       - name: Run tests
         run: |
           df -h
-          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
           df -h
           pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           df -h
@@ -196,7 +196,7 @@ jobs:

       - name: Run tests
         run: |
-          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
           pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           pytest -v --durations=10 tests/cli/

diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml
index b7de7ca52..57b308abd 100644
--- a/examples/llama-3/3b-fp8-fsdp2.yaml
+++ b/examples/llama-3/3b-fp8-fsdp2.yaml
@@ -29,7 +29,6 @@ flex_attention: true
 flex_attn_compile_kwargs:
   dynamic: false
   mode: max-autotune-no-cudagraphs
-save_strategy: no
 torch_compile: true

 wandb_project:
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index 3c4ace7b0..b53c6576b 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -227,6 +227,7 @@ def load_cfg(
             cfg,
             capabilities={
                 "bf16": is_torch_bf16_gpu_available(),
+                "fp8": compute_supports_fp8(),
                 "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
                 "compute_capability": gpu_version,
             },
@@ -259,3 +260,11 @@
         )

     return cfg
+
+
+def compute_supports_fp8() -> bool:
+    try:
+        compute_capability = torch.cuda.get_device_capability()
+        return compute_capability >= (9, 0)
+    except RuntimeError:
+        return False
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index bd6a61177..e0c9acd4d 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -2,6 +2,7 @@

 from typing import Annotated, Any, Literal

+from accelerate.utils import is_fp8_available
 from annotated_types import MinLen
 from packaging import version
 from pydantic import (
@@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         return self

+    @model_validator(mode="after")
+    def check_fp8(self):
+        if self.fp8 and not self.capabilities.fp8:
+            raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
+        elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():
+            raise ValueError(
+                "fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
+            )
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_sdpa_bf16(cls, data):