fsdp2 support
This commit is contained in:
@@ -14,7 +14,7 @@ packaging==23.2
|
||||
peft==0.15.0
|
||||
transformers==4.50.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.5.2
|
||||
accelerate @ git+https://github.com/S1ro1/accelerate.git@dev/fsdp2
|
||||
datasets==3.5.0
|
||||
deepspeed==0.15.4
|
||||
trl==0.16.0
|
||||
|
||||
@@ -538,6 +538,8 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if str(cfg.fsdp_version) == "2":
|
||||
os.environ["FSDP_VERSION"] = "2"
|
||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
|
||||
if cfg.fsdp_config.fsdp_offload_params:
|
||||
@@ -556,6 +558,10 @@ def setup_fsdp_envs(cfg):
|
||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
|
||||
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||
)
|
||||
if cfg.fsdp_config.fsdp_reshard_after_forward is not None:
|
||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = (
|
||||
"true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false"
|
||||
)
|
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
|
||||
Reference in New Issue
Block a user