Add FSDP v2 swap memory support + QLoRA compatibility fixes (#3167)

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
Grant Holmes (Ren)
2025-09-26 04:23:59 -05:00
committed by GitHub
parent 7fa8ac40cd
commit 850c1a5f8d
6 changed files with 71 additions and 16 deletions

View File

@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
import copy
import functools
import os
import sys
import torch
@@ -277,6 +278,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
mesh = getattr(accelerator.state, "device_mesh", None)
# Disable memory pinning if requested
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
fsdp2_plugin.cpu_offload.pin_memory = False
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
@@ -341,7 +347,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
)
if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)

View File

@@ -816,21 +816,22 @@ class OptimizationValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_cpu_offload_pin_memory(cls, data):
    """Validate the raw config when `fsdp_config.cpu_offload_pin_memory` is disabled.

    Disabling CPU-offload memory pinning is only supported on FSDP v2, and it
    only has an effect when parameters are actually offloaded to CPU, so both
    preconditions are enforced here before model construction.

    Args:
        data: the raw (pre-validation) config mapping.

    Returns:
        The unmodified config mapping.

    Raises:
        ValueError: if `cpu_offload_pin_memory` is disabled while
            `fsdp_version` is not 2, or while `offload_params` is not enabled.
    """
    # No FSDP config at all -> nothing to validate.
    if not (fsdp_config := data.get("fsdp_config")):
        return data
    # Only an explicit `false` needs validation; unset/True are always fine.
    if fsdp_config.get("cpu_offload_pin_memory") is False:
        # fsdp_version may arrive as an int or a string in raw config; normalize.
        if str(data.get("fsdp_version")) != "2":
            raise ValueError(
                "FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2"
            )
        # Pinning only applies to memory that is offloaded in the first place.
        if not fsdp_config.get("offload_params"):
            raise ValueError(
                "disabling cpu_offload_pin_memory requires enabling offload_params"
            )
    return data
@model_validator(mode="before")
@classmethod

View File

@@ -595,6 +595,10 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
cfg.fsdp_config.cpu_offload_pin_memory
).lower()
if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.transformer_layer_cls_to_wrap: