From 03dcf1a5ead2a1606298acafdb1448a203484804 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 22 Mar 2025 21:47:54 -0400 Subject: [PATCH] fsdp2 support --- requirements.txt | 2 +- src/axolotl/utils/trainer.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 78ced5728..4b5a12bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ packaging==23.2 peft==0.15.0 transformers==4.50.3 tokenizers>=0.21.1 -accelerate==1.5.2 +accelerate @ git+https://github.com/S1ro1/accelerate.git@dev/fsdp2 datasets==3.5.0 deepspeed==0.15.4 trl==0.16.0 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c370707b6..895375a0a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -538,6 +538,8 @@ def setup_deepspeed_env(cfg, stage=None): def setup_fsdp_envs(cfg): os.environ["ACCELERATE_USE_FSDP"] = "true" + if str(cfg.fsdp_version) == "2": + os.environ["FSDP_VERSION"] = "2" if cfg.fsdp_config.fsdp_activation_checkpointing: os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true" if cfg.fsdp_config.fsdp_offload_params: @@ -556,6 +558,10 @@ def setup_fsdp_envs(cfg): os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ( cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ) + if cfg.fsdp_config.fsdp_reshard_after_forward is not None: + os.environ["FSDP_RESHARD_AFTER_FORWARD"] = ( + "true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false" + ) def prepare_optim_env(cfg):