Add FSDP v2 swap memory support + QLoRA compatibility fixes (#3167)
Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
committed by
GitHub
parent
7fa8ac40cd
commit
850c1a5f8d
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "FDSP + QLoRA"
|
||||
title: "FSDP + QLoRA"
|
||||
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
|
||||
format:
|
||||
html:
|
||||
@@ -23,6 +23,12 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Enabling Swap for FSDP2
|
||||
|
||||
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in the FSDP config.
|
||||
|
||||
This disables memory pinning, allowing FSDP to use disk swap space as a fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource-constrained systems.
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
@@ -66,6 +66,7 @@ fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
@@ -277,6 +278,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
mesh = getattr(accelerator.state, "device_mesh", None)
|
||||
|
||||
# Disable memory pinning if requested
|
||||
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
|
||||
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
|
||||
fsdp2_plugin.cpu_offload.pin_memory = False
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
@@ -341,7 +347,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading:
|
||||
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
|
||||
fsdp2_load_full_state_dict(
|
||||
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
|
||||
)
|
||||
|
||||
@@ -816,21 +816,22 @@ class OptimizationValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
|
||||
fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
|
||||
fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
|
||||
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
|
||||
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
|
||||
if fsdp_config and fsdp_version == 2:
|
||||
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
||||
load_in_8bit or load_in_4bit
|
||||
):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp2_cpu_offload_pin_memory(cls, data):
|
||||
if not (fsdp_config := data.get("fsdp_config")):
|
||||
return data
|
||||
|
||||
if fsdp_config.get("cpu_offload_pin_memory") is False:
|
||||
if str(data.get("fsdp_version")) != "2":
|
||||
raise ValueError(
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
||||
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
||||
"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2"
|
||||
)
|
||||
return self
|
||||
if not fsdp_config.get("offload_params"):
|
||||
raise ValueError(
|
||||
"disabling cpu_offload_pin_memory requires enabling offload_params"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -595,6 +595,10 @@ def setup_fsdp_envs(cfg):
|
||||
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
|
||||
if cfg.fsdp_config.state_dict_type:
|
||||
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
|
||||
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
|
||||
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
|
||||
cfg.fsdp_config.cpu_offload_pin_memory
|
||||
).lower()
|
||||
if cfg.fsdp_config.auto_wrap_policy:
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
|
||||
if cfg.fsdp_config.transformer_layer_cls_to_wrap:
|
||||
|
||||
@@ -61,12 +61,50 @@ class TestFSDPValidation:
|
||||
},
|
||||
fsdp_version=2,
|
||||
)
|
||||
validated_cfg = validate_config(cfg)
|
||||
assert validated_cfg.fsdp_version == 2
|
||||
assert validated_cfg.fsdp_config.cpu_ram_efficient_loading is True
|
||||
|
||||
def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"cpu_offload_pin_memory": False,
|
||||
"offload_params": False,
|
||||
},
|
||||
fsdp_version=2,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
|
||||
match="disabling cpu_offload_pin_memory requires enabling offload_params",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"cpu_offload_pin_memory": False,
|
||||
"offload_params": True,
|
||||
},
|
||||
fsdp_version=1,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"cpu_offload_pin_memory": False,
|
||||
"offload_params": True,
|
||||
},
|
||||
fsdp_version=2,
|
||||
)
|
||||
validated_cfg = validate_config(cfg)
|
||||
assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False
|
||||
assert validated_cfg.fsdp_config.offload_params is True
|
||||
|
||||
def test_fsdp_prefixes_removed(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
|
||||
Reference in New Issue
Block a user