Fix FSDP2 config for CI

This commit is contained in:
Wing Lian
2025-04-06 07:55:54 -04:00
parent ad7293f617
commit 9329db9c3a

View File

@@ -14,7 +14,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -450,6 +450,7 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_2_6_0
@pytest.mark.parametrize(
"fsdp_reshard_after_forward",
[True, False],
@@ -487,7 +488,9 @@ class TestMultiGPULlama:
],
"fsdp_config": {
"fsdp_version": 2,
"fsdp_limit_all_gathers": True,
"fsdp_forward_prefetch": True,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": True,
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",