Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
-from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
+from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
|
||||
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
+@require_hopper
|
||||
def test_fp8_fsdp2_smoke(self, temp_dir):
|
||||
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
||||
cfg = DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user