Add FSDP v2 swap memory support + QLoRA compatibility fixes (#3167)

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
Grant Holmes (Ren)
2025-09-26 04:23:59 -05:00
committed by GitHub
parent 7fa8ac40cd
commit 850c1a5f8d
6 changed files with 71 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
---
title: "FDSP + QLoRA"
title: "FSDP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format:
html:
@@ -23,6 +23,12 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Enabling Swap for FSDP2
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.

View File

@@ -66,6 +66,7 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
import copy
import functools
import os
import sys
import torch
@@ -277,6 +278,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
mesh = getattr(accelerator.state, "device_mesh", None)
# Disable memory pinning if requested
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
fsdp2_plugin.cpu_offload.pin_memory = False
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
@@ -341,7 +347,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
)
if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)

View File

@@ -816,21 +816,22 @@ class OptimizationValidationMixin:
)
return data
@model_validator(mode="after")
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
if fsdp_config and fsdp_version == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
load_in_8bit or load_in_4bit
):
@model_validator(mode="before")
@classmethod
def check_fsdp2_cpu_offload_pin_memory(cls, data):
if not (fsdp_config := data.get("fsdp_config")):
return data
if fsdp_config.get("cpu_offload_pin_memory") is False:
if str(data.get("fsdp_version")) != "2":
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2"
)
return self
if not fsdp_config.get("offload_params"):
raise ValueError(
"disabling cpu_offload_pin_memory requires enabling offload_params"
)
return data
@model_validator(mode="before")
@classmethod

View File

@@ -595,6 +595,10 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
cfg.fsdp_config.cpu_offload_pin_memory
).lower()
if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.transformer_layer_cls_to_wrap:

View File

@@ -61,12 +61,50 @@ class TestFSDPValidation:
},
fsdp_version=2,
)
validated_cfg = validate_config(cfg)
assert validated_cfg.fsdp_version == 2
assert validated_cfg.fsdp_config.cpu_ram_efficient_loading is True
def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": False,
},
fsdp_version=2,
)
with pytest.raises(
ValueError,
match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
match="disabling cpu_offload_pin_memory requires enabling offload_params",
):
validate_config(cfg)
def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": True,
},
fsdp_version=1,
)
with pytest.raises(
ValueError,
match="FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2",
):
validate_config(cfg)
def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": True,
},
fsdp_version=2,
)
validated_cfg = validate_config(cfg)
assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False
assert validated_cfg.fsdp_config.offload_params is True
def test_fsdp_prefixes_removed(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={