Add FSDP v2 swap memory support + QLoRA compatibility fixes (#3167)
Co-authored-by: salman <salman.mohammadi@outlook.com>
commit 850c1a5f8d, parent 7fa8ac40cd, committed by GitHub
@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
 
 import copy
 import functools
+import os
 import sys
 
 import torch
@@ -277,6 +278,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
 
     mesh = getattr(accelerator.state, "device_mesh", None)
 
+    # Disable memory pinning if requested
+    offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
+    if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
+        fsdp2_plugin.cpu_offload.pin_memory = False
+
     fsdp2_kwargs = {
         "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
         "offload_policy": fsdp2_plugin.cpu_offload,
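The new block disables pin memory only when both conditions hold: CPU offload is actually configured (the policy is a CPUOffloadPolicy) and the FSDP_CPU_OFFLOAD_PIN_MEMORY environment variable explicitly says "false". Pinned host memory speeds up host-to-device copies but counts against physical RAM, so opting out lets offloaded parameters page out to swap, which is the point of the commit title. A minimal runnable sketch of the same gate, using a hypothetical dataclass stand-in for accelerate's CPUOffloadPolicy so it runs without accelerate installed:

import os
from dataclasses import dataclass


@dataclass
class CPUOffloadPolicy:
    """Hypothetical stand-in for accelerate's CPUOffloadPolicy."""
    pin_memory: bool = True


os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = "false"  # as setup_fsdp_envs sets it

policy = CPUOffloadPolicy()
offload_to_cpu = isinstance(policy, CPUOffloadPolicy)

# Same two-part gate as the patch: offload must be active AND the opt-out explicit.
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
    policy.pin_memory = False

assert policy.pin_memory is False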
@@ -341,7 +347,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     )
 
     if fsdp2_plugin.cpu_ram_efficient_loading:
-        offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
        fsdp2_load_full_state_dict(
            accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
        )
@@ -816,21 +816,22 @@ class OptimizationValidationMixin:
         )
         return data
 
-    @model_validator(mode="after")
-    def check_fsdp2_base_model_quant_ram_efficient_loading(self):
-        fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
-        fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
-        load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
-        load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
-        if fsdp_config and fsdp_version == 2:
-            if fsdp_config.get("cpu_ram_efficient_loading") and (
-                load_in_8bit or load_in_4bit
-            ):
-                raise ValueError(
-                    "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
-                    "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
-                )
-        return self
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp2_cpu_offload_pin_memory(cls, data):
+        if not (fsdp_config := data.get("fsdp_config")):
+            return data
+
+        if fsdp_config.get("cpu_offload_pin_memory") is False:
+            if str(data.get("fsdp_version")) != "2":
+                raise ValueError(
+                    "FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2"
+                )
+            if not fsdp_config.get("offload_params"):
+                raise ValueError(
+                    "disabling cpu_offload_pin_memory requires enabling offload_params"
+                )
+        return data
 
     @model_validator(mode="before")
     @classmethod
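This hunk drops the old FSDP2 + quantization guard (the QLoRA compatibility fix from the commit title) and adds a guard for the new pin-memory opt-out instead: disabling pinning is FSDP2-only and meaningless unless offload_params is enabled. A runnable sketch of the new checks on plain dicts (validate_pin_memory is an illustrative name; in the codebase the same logic runs as the pydantic model_validator above):

def validate_pin_memory(data: dict) -> dict:
    # Only relevant when the user explicitly sets cpu_offload_pin_memory: false
    if not (fsdp_config := data.get("fsdp_config")):
        return data
    if fsdp_config.get("cpu_offload_pin_memory") is False:
        if str(data.get("fsdp_version")) != "2":
            raise ValueError("FSDP1 does not support disabling cpu_offload_pin_memory")
        if not fsdp_config.get("offload_params"):
            raise ValueError(
                "disabling cpu_offload_pin_memory requires enabling offload_params"
            )
    return data


# Accepted: FSDP2 with params offloaded to CPU
validate_pin_memory(
    {
        "fsdp_version": 2,
        "fsdp_config": {"cpu_offload_pin_memory": False, "offload_params": True},
    }
)

# Rejected: pin-memory opt-out without offload_params (raises ValueError)
# validate_pin_memory(
#     {"fsdp_version": 2, "fsdp_config": {"cpu_offload_pin_memory": False}}
# )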
@@ -595,6 +595,10 @@ def setup_fsdp_envs(cfg):
         os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
     if cfg.fsdp_config.state_dict_type:
         os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
+    if cfg.fsdp_config.cpu_offload_pin_memory is not None:
+        os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
+            cfg.fsdp_config.cpu_offload_pin_memory
+        ).lower()
     if cfg.fsdp_config.auto_wrap_policy:
         os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
     if cfg.fsdp_config.transformer_layer_cls_to_wrap:
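The str(...).lower() conversion is what lets the boolean survive the trip through the environment: Python stringifies booleans as "True"/"False", while the monkeypatch above compares against the lowercase "false". A quick sketch of the round trip:

import os

cpu_offload_pin_memory = False  # as parsed from the YAML config

# Booleans stringify as "True"/"False"; lowercase so the consumer's
# == "false" comparison in fsdp2_prepare_model matches.
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(cpu_offload_pin_memory).lower()

assert os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] == "false"
assert os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false"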