fix check for fp8 capability (#3324)
* fix check for fp8 capability * handle non-cuda compute * reduce concurrency of tests
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -114,7 +114,7 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
@@ -196,7 +196,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ flex_attention: true
|
|||||||
flex_attn_compile_kwargs:
|
flex_attn_compile_kwargs:
|
||||||
dynamic: false
|
dynamic: false
|
||||||
mode: max-autotune-no-cudagraphs
|
mode: max-autotune-no-cudagraphs
|
||||||
save_strategy: no
|
|
||||||
torch_compile: true
|
torch_compile: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -227,6 +227,7 @@ def load_cfg(
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
|
"fp8": compute_supports_fp8(),
|
||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
@@ -259,3 +260,11 @@ def load_cfg(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def compute_supports_fp8() -> bool:
    """Report whether the current CUDA device meets the fp8 capability bar.

    Returns:
        True when ``torch.cuda.get_device_capability()`` reports a compute
        capability of (9, 0) or higher. False when CUDA is unavailable or
        the capability query fails.
    """
    # Non-CUDA compute (CPU-only builds, other accelerators) has no CUDA
    # capability to query, so answer without touching the CUDA runtime.
    if not torch.cuda.is_available():
        return False

    try:
        capability = torch.cuda.get_device_capability()
    except (RuntimeError, AssertionError):
        # torch raises AssertionError when it was not compiled with CUDA and
        # RuntimeError for driver/initialisation failures; both mean "no fp8".
        return False

    return tuple(capability) >= (9, 0)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
from accelerate.utils import is_fp8_available
|
||||||
from annotated_types import MinLen
|
from annotated_types import MinLen
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
def check_fp8(self):
    """Validate that an fp8 request can be honoured on this machine.

    Returns:
        The model instance itself, unchanged, when validation passes.

    Raises:
        ValueError: If ``fp8`` is set but ``capabilities.fp8`` is falsy, or
            if ``fp8`` is set and supported but accelerate's
            ``is_fp8_available()`` returns a falsy value.
    """
    # Nothing to validate unless fp8 was actually requested.
    if not self.fp8:
        return self

    # Hardware check first: the detected capabilities must include fp8.
    if not self.capabilities.fp8:
        raise ValueError("fp8 requested, but fp8 is not supported on this GPU")

    # Software check: accelerate must report an fp8 backend as available.
    if not is_fp8_available():
        raise ValueError(
            "fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
        )

    return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
def check_sample_packing_w_sdpa_bf16(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user