Compare commits

...

3 Commits

Author SHA1 Message Date
Wing Lian
af3f981f51 allow 8bit optims with fsdp2 2025-04-06 07:49:06 -04:00
Wing Lian
52b96031b4 use accelerate release 1.6.0 2025-04-06 07:49:05 -04:00
Wing Lian
03dcf1a5ea fsdp2 support 2025-04-06 07:49:05 -04:00
3 changed files with 8 additions and 1 deletion

View File

@@ -14,7 +14,7 @@ packaging==23.2
peft==0.15.0
transformers==4.50.3
tokenizers>=0.21.1
-accelerate==1.5.2
+accelerate==1.6.0
datasets==3.5.0
deepspeed==0.15.4
trl==0.16.0

View File

@@ -950,6 +950,7 @@ class AxolotlInputConfig(
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
and str(data["fsdp_config"].get("fsdp_version")) != "2"
):
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"

View File

@@ -538,6 +538,8 @@ def setup_deepspeed_env(cfg, stage=None):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if str(cfg.fsdp_version) == "2":
os.environ["FSDP_VERSION"] = "2"
if cfg.fsdp_config.fsdp_activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
@@ -556,6 +558,10 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
)
if cfg.fsdp_config.fsdp_reshard_after_forward is not None:
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = (
"true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false"
)
def prepare_optim_env(cfg):