FSDP2 + LoRA kernels (#2992)

* impl fix

* smoke tests

* patches for fsdp2 + qlora compat

* nit

* working fix

* working fix

* fix merge

* minifying patches; update bnb dep

* renaming; adding tests

* remove duplicate test, add dora guard

* generalize __torch_function__

* revert generalization

* update comments
This commit is contained in:
Dan Saunders
2025-08-03 20:05:17 -04:00
committed by GitHub
parent deac7b18a1
commit e758343cac
7 changed files with 520 additions and 23 deletions

View File

@@ -14,6 +14,7 @@ from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
@@ -54,8 +55,21 @@ def get_lora_parameters(
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
linear_A = proj.lora_A[active_adapter]
linear_B = proj.lora_B[active_adapter]
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
# We fuse linear layers + LoRA adapters calculations into a single
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
# Note that we don't apply resharding later in this module (it gets messy quickly),
# but LoRA parameters are generally small enough that this is not an issue.
if isinstance(linear_A.weight, DTensor):
linear_A.unshard()
linear_B.unshard()
A = linear_A.weight
B = linear_B.weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
@@ -102,8 +116,8 @@ def matmul_lora(
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
A, B = A.t().to(dtype), B.t().to(dtype)
out += s * X @ A @ B
return out.view(batch, seq_len, -1) if reshape else out

View File

@@ -65,6 +65,7 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -260,6 +261,23 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (

View File

@@ -0,0 +1,205 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
"""
import importlib
import inspect
import torch
from torch.nn import Parameter
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patched_torch_function(cls, func, types, args=(), kwargs=None):
"""
Patched version of Params4bit.__torch_function__ for preserving Params4bit
class identity and attributes.
"""
if kwargs is None:
kwargs = {}
if func in [torch.chunk, torch.split]:
tensor = args[0]
result = Parameter.__torch_function__(func, types, args, kwargs)
if isinstance(result, tuple):
return tuple(
cls(
data=chunk,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
for chunk in result
)
return cls(
data=result,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
return Parameter.__torch_function__(func, types, args, kwargs)
# pylint: disable=protected-access
def apply_bnb_torch_function_patch():
"""
Patch Params4bit.__torch_function__ using Axolotl-style approach.
Returns:
True if patching succeeded, False otherwise.
"""
from bitsandbytes.nn.modules import Params4bit
Params4bit.__torch_function__ = classmethod(patched_torch_function)
LOG.info("Successfully patched Params4bit.__torch_function__")
# pylint: disable=protected-access
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam._init_sharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
patched_param_creation = """ import bitsandbytes as bnb
if isinstance(param, bnb.nn.modules.Params4bit):
self.sharded_param = bnb.nn.modules.Params4bit(
data=sharded_param,
requires_grad=param.requires_grad,
quant_state=param.quant_state,
blocksize=param.blocksize,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
quant_storage=param.quant_storage,
module=param.module,
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def _init_sharded_param(",
"def patched_init_sharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
patched_param_creation = """ import bitsandbytes as bnb
local_tensor = self.sharded_param._local_tensor
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
self._unsharded_param = bnb.nn.modules.Params4bit(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
quant_state=local_tensor.quant_state,
blocksize=local_tensor.blocksize,
compress_statistics=local_tensor.compress_statistics,
quant_type=local_tensor.quant_type,
quant_storage=local_tensor.quant_storage,
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def init_unsharded_param(",
"def patched_init_unsharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")

View File

@@ -559,20 +559,6 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
@@ -619,7 +605,7 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernel_8bit(cls, data):
def check_lora_kernels_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
@@ -627,20 +613,36 @@ class LoRAValidationMixin:
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with 8-bit LoRA a the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernel_rl(cls, data):
def check_lora_kernels_dora(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_rl(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("rl"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with RL at the moment."
)
return data