From 9329db9c3af25b9b5e03e3b0b6721c45939449e3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 6 Apr 2025 07:55:54 -0400 Subject: [PATCH] fix fsdp2 config for ci --- tests/e2e/multigpu/test_llama.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 64750999a..e2b797cf5 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -14,7 +14,7 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault -from tests.e2e.utils import check_tensorboard +from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 LOG = logging.getLogger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" @@ -450,6 +450,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" ) + @require_torch_2_6_0 @pytest.mark.parametrize( "fsdp_reshard_after_forward", [True, False], @@ -487,7 +488,9 @@ class TestMultiGPULlama: ], "fsdp_config": { "fsdp_version": 2, - "fsdp_limit_all_gathers": True, + "fsdp_forward_prefetch": True, + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": True, "fsdp_offload_params": False, "fsdp_cpu_ram_efficient_loading": False, "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",