diff --git a/requirements.txt b/requirements.txt index 78ced5728..4b5a12bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ packaging==23.2 peft==0.15.0 transformers==4.50.3 tokenizers>=0.21.1 -accelerate==1.5.2 +accelerate @ git+https://github.com/S1ro1/accelerate.git@dev/fsdp2 datasets==3.5.0 deepspeed==0.15.4 trl==0.16.0 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c370707b6..895375a0a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -538,6 +538,8 @@ def setup_deepspeed_env(cfg, stage=None): def setup_fsdp_envs(cfg): os.environ["ACCELERATE_USE_FSDP"] = "true" + if str(cfg.fsdp_version) == "2": + os.environ["FSDP_VERSION"] = "2" if cfg.fsdp_config.fsdp_activation_checkpointing: os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true" if cfg.fsdp_config.fsdp_offload_params: @@ -556,6 +558,10 @@ def setup_fsdp_envs(cfg): os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ( cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ) + if cfg.fsdp_config.fsdp_reshard_after_forward is not None: + os.environ["FSDP_RESHARD_AFTER_FORWARD"] = ( + "true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false" + ) def prepare_optim_env(cfg):