FSDP2 + LoRA kernels (#2992)
* impl fix * smoke tests * patches for fsdp2 + qlora compat * nit * working fix * working fix * fix merge * minifying patches; update bnb dep * renaming; adding tests * remove duplicate test, add dora guard * generalize __torch_function__ * revert generalization * update comments
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.46.0
|
bitsandbytes==0.46.1
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import Callable
|
|||||||
import torch
|
import torch
|
||||||
from bitsandbytes.functional import QuantState
|
from bitsandbytes.functional import QuantState
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
from .geglu import geglu_backward, geglu_forward
|
from .geglu import geglu_backward, geglu_forward
|
||||||
from .quantize import dequantize
|
from .quantize import dequantize
|
||||||
@@ -54,8 +55,21 @@ def get_lora_parameters(
|
|||||||
if hasattr(proj, "active_adapters")
|
if hasattr(proj, "active_adapters")
|
||||||
else proj.active_adapter
|
else proj.active_adapter
|
||||||
)
|
)
|
||||||
A = proj.lora_A[active_adapter].weight
|
|
||||||
B = proj.lora_B[active_adapter].weight
|
linear_A = proj.lora_A[active_adapter]
|
||||||
|
linear_B = proj.lora_B[active_adapter]
|
||||||
|
|
||||||
|
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
|
||||||
|
# We fuse linear layers + LoRA adapters calculations into a single
|
||||||
|
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
|
||||||
|
# Note that we don't apply resharding later in this module (it gets messy quickly),
|
||||||
|
# but LoRA parameters are generally small enough that this is not an issue.
|
||||||
|
if isinstance(linear_A.weight, DTensor):
|
||||||
|
linear_A.unshard()
|
||||||
|
linear_B.unshard()
|
||||||
|
|
||||||
|
A = linear_A.weight
|
||||||
|
B = linear_B.weight
|
||||||
s = proj.scaling[active_adapter]
|
s = proj.scaling[active_adapter]
|
||||||
|
|
||||||
quant_state = getattr(W, "quant_state", None)
|
quant_state = getattr(W, "quant_state", None)
|
||||||
@@ -102,8 +116,8 @@ def matmul_lora(
|
|||||||
del W
|
del W
|
||||||
|
|
||||||
if A is not None:
|
if A is not None:
|
||||||
A, B = A.t(), B.t()
|
A, B = A.t().to(dtype), B.t().to(dtype)
|
||||||
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
out += s * X @ A @ B
|
||||||
|
|
||||||
return out.view(batch, seq_len, -1) if reshape else out
|
return out.view(batch, seq_len, -1) if reshape else out
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class PatchManager:
|
|||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
|
self._apply_fsdp2_bnb_patches()
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
@@ -260,6 +261,23 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_fsdp2_bnb_patches(self):
|
||||||
|
"""Apply FSDP2 BNB patches."""
|
||||||
|
if (
|
||||||
|
self.cfg.fsdp_config
|
||||||
|
and str(self.cfg.fsdp_version) == "2"
|
||||||
|
and self.cfg.adapter == "qlora"
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
|
apply_bnb_torch_function_patch,
|
||||||
|
apply_init_sharded_param_patch,
|
||||||
|
apply_init_unsharded_param_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_bnb_torch_function_patch()
|
||||||
|
apply_init_sharded_param_patch()
|
||||||
|
apply_init_unsharded_param_patch()
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
from axolotl.monkeypatch.tiled_mlp import (
|
from axolotl.monkeypatch.tiled_mlp import (
|
||||||
|
|||||||
205
src/axolotl/monkeypatch/fsdp2_qlora.py
Normal file
205
src/axolotl/monkeypatch/fsdp2_qlora.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
||||||
|
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||||
|
|
||||||
|
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
||||||
|
Params4bit parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def patched_torch_function(cls, func, types, args=(), kwargs=None):
|
||||||
|
"""
|
||||||
|
Patched version of Params4bit.__torch_function__ for preserving Params4bit
|
||||||
|
class identity and attributes.
|
||||||
|
"""
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
if func in [torch.chunk, torch.split]:
|
||||||
|
tensor = args[0]
|
||||||
|
result = Parameter.__torch_function__(func, types, args, kwargs)
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
return tuple(
|
||||||
|
cls(
|
||||||
|
data=chunk,
|
||||||
|
requires_grad=tensor.requires_grad,
|
||||||
|
quant_state=tensor.quant_state,
|
||||||
|
blocksize=tensor.blocksize,
|
||||||
|
compress_statistics=tensor.compress_statistics,
|
||||||
|
quant_type=tensor.quant_type,
|
||||||
|
quant_storage=tensor.quant_storage,
|
||||||
|
module=tensor.module,
|
||||||
|
bnb_quantized=tensor.bnb_quantized,
|
||||||
|
)
|
||||||
|
for chunk in result
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
data=result,
|
||||||
|
requires_grad=tensor.requires_grad,
|
||||||
|
quant_state=tensor.quant_state,
|
||||||
|
blocksize=tensor.blocksize,
|
||||||
|
compress_statistics=tensor.compress_statistics,
|
||||||
|
quant_type=tensor.quant_type,
|
||||||
|
quant_storage=tensor.quant_storage,
|
||||||
|
module=tensor.module,
|
||||||
|
bnb_quantized=tensor.bnb_quantized,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Parameter.__torch_function__(func, types, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def apply_bnb_torch_function_patch():
|
||||||
|
"""
|
||||||
|
Patch Params4bit.__torch_function__ using Axolotl-style approach.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if patching succeeded, False otherwise.
|
||||||
|
"""
|
||||||
|
from bitsandbytes.nn.modules import Params4bit
|
||||||
|
|
||||||
|
Params4bit.__torch_function__ = classmethod(patched_torch_function)
|
||||||
|
|
||||||
|
LOG.info("Successfully patched Params4bit.__torch_function__")
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def apply_init_sharded_param_patch():
|
||||||
|
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||||
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
|
# Get original source
|
||||||
|
original_source = inspect.getsource(FSDPParam._init_sharded_param)
|
||||||
|
original_source, _ = detab_code(original_source)
|
||||||
|
|
||||||
|
# Define the replacement
|
||||||
|
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||||
|
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||||
|
|
||||||
|
patched_param_creation = """ import bitsandbytes as bnb
|
||||||
|
if isinstance(param, bnb.nn.modules.Params4bit):
|
||||||
|
self.sharded_param = bnb.nn.modules.Params4bit(
|
||||||
|
data=sharded_param,
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
quant_state=param.quant_state,
|
||||||
|
blocksize=param.blocksize,
|
||||||
|
compress_statistics=param.compress_statistics,
|
||||||
|
quant_type=param.quant_type,
|
||||||
|
quant_storage=param.quant_storage,
|
||||||
|
module=param.module,
|
||||||
|
bnb_quantized=param.bnb_quantized,
|
||||||
|
)
|
||||||
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
|
else:
|
||||||
|
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||||
|
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||||
|
|
||||||
|
# Apply the replacement
|
||||||
|
if original_param_creation in original_source:
|
||||||
|
patched_source = original_source.replace(
|
||||||
|
original_param_creation, patched_param_creation
|
||||||
|
)
|
||||||
|
patched_source = patched_source.replace(
|
||||||
|
"def _init_sharded_param(",
|
||||||
|
"def patched_init_sharded_param(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load necessary imports
|
||||||
|
module_name = FSDPParam.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in patched_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
|
# Replace the method
|
||||||
|
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||||
|
else:
|
||||||
|
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_init_unsharded_param_patch():
|
||||||
|
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||||
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
|
# Get original source
|
||||||
|
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
|
||||||
|
original_source, _ = detab_code(original_source)
|
||||||
|
|
||||||
|
# Define the replacement
|
||||||
|
original_param_creation = """ self._unsharded_param = nn.Parameter(
|
||||||
|
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||||
|
)"""
|
||||||
|
|
||||||
|
patched_param_creation = """ import bitsandbytes as bnb
|
||||||
|
local_tensor = self.sharded_param._local_tensor
|
||||||
|
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
|
||||||
|
self._unsharded_param = bnb.nn.modules.Params4bit(
|
||||||
|
data=unsharded_param,
|
||||||
|
requires_grad=self.sharded_param.requires_grad,
|
||||||
|
quant_state=local_tensor.quant_state,
|
||||||
|
blocksize=local_tensor.blocksize,
|
||||||
|
compress_statistics=local_tensor.compress_statistics,
|
||||||
|
quant_type=local_tensor.quant_type,
|
||||||
|
quant_storage=local_tensor.quant_storage,
|
||||||
|
module=local_tensor.module,
|
||||||
|
bnb_quantized=local_tensor.bnb_quantized,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._unsharded_param = nn.Parameter(
|
||||||
|
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||||
|
)"""
|
||||||
|
|
||||||
|
# Apply the replacement
|
||||||
|
if original_param_creation in original_source:
|
||||||
|
patched_source = original_source.replace(
|
||||||
|
original_param_creation, patched_param_creation
|
||||||
|
)
|
||||||
|
patched_source = patched_source.replace(
|
||||||
|
"def init_unsharded_param(",
|
||||||
|
"def patched_init_unsharded_param(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load necessary imports
|
||||||
|
module_name = FSDPParam.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in patched_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
|
# Replace the method
|
||||||
|
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||||
|
else:
|
||||||
|
LOG.warning("Could not find target code for patching")
|
||||||
@@ -559,20 +559,6 @@ class LoRAValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_lora_8bit(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("lora_mlp_kernel")
|
|
||||||
or data.get("lora_qkv_kernel")
|
|
||||||
or data.get("lora_o_kernel")
|
|
||||||
):
|
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_axolotl_unsloth(cls, data):
|
def check_lora_axolotl_unsloth(cls, data):
|
||||||
@@ -619,7 +605,7 @@ class LoRAValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernel_8bit(cls, data):
|
def check_lora_kernels_8bit(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
@@ -627,20 +613,36 @@ class LoRAValidationMixin:
|
|||||||
):
|
):
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||||
|
"compatible with 8-bit LoRA a the moment."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernel_rl(cls, data):
|
def check_lora_kernels_dora(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
) and data.get("peft_use_dora"):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||||
|
"compatible with DoRA at the moment."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lora_kernels_rl(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
) and data.get("rl"):
|
) and data.get("rl"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||||
|
"compatible with RL at the moment."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -174,6 +174,69 @@ class TestFSDP2:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_sft_kernels(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
def test_qlora_sft(self, temp_dir):
|
def test_qlora_sft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -236,6 +299,70 @@ class TestFSDP2:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_qlora_sft_kernels(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
def test_dpo_fft(self, temp_dir):
|
def test_dpo_fft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
131
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
131
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Integration tests for FSDP Params4bit patches."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
|
apply_bnb_torch_function_patch,
|
||||||
|
patched_torch_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_params4bit():
|
||||||
|
"""Create a mock Params4bit instance with test attributes."""
|
||||||
|
mock_instance = Mock()
|
||||||
|
mock_instance.requires_grad = True
|
||||||
|
mock_instance.quant_state = "test_state"
|
||||||
|
mock_instance.blocksize = 128
|
||||||
|
mock_instance.compress_statistics = True
|
||||||
|
mock_instance.quant_type = "fp4"
|
||||||
|
mock_instance.quant_storage = "test_storage"
|
||||||
|
mock_instance.module = "test_module"
|
||||||
|
mock_instance.bnb_quantized = True
|
||||||
|
return mock_instance
|
||||||
|
|
||||||
|
|
||||||
|
class TestBnbTorchFunctionPatch:
|
||||||
|
"""Test the Params4bit.__torch_function__ patch."""
|
||||||
|
|
||||||
|
def test_apply_patch(self):
|
||||||
|
"""Test that the patch can be applied."""
|
||||||
|
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
|
||||||
|
apply_bnb_torch_function_patch()
|
||||||
|
assert hasattr(mock_cls, "__torch_function__")
|
||||||
|
assert isinstance(mock_cls.__torch_function__, classmethod)
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
|
||||||
|
"""Test that torch.chunk preserves Params4bit attributes."""
|
||||||
|
mock_cls = Mock()
|
||||||
|
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
|
||||||
|
|
||||||
|
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
|
||||||
|
result = patched_torch_function(
|
||||||
|
mock_cls,
|
||||||
|
torch.chunk,
|
||||||
|
(type(mock_params4bit),),
|
||||||
|
args=(mock_params4bit, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, tuple)
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
# Check that Params4bit constructor was called with preserved attributes
|
||||||
|
assert mock_cls.call_count == 2
|
||||||
|
for call in mock_cls.call_args_list:
|
||||||
|
kwargs = call[1]
|
||||||
|
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
|
||||||
|
assert kwargs["quant_state"] == mock_params4bit.quant_state
|
||||||
|
assert kwargs["blocksize"] == mock_params4bit.blocksize
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
def test_other_functions_fallback(self, mock_params4bit):
|
||||||
|
"""Test that non-chunk/split functions use Parameter fallback."""
|
||||||
|
mock_cls = Mock()
|
||||||
|
fallback_result = torch.tensor([5, 6, 7])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
|
||||||
|
) as mock_fallback:
|
||||||
|
result = patched_torch_function(
|
||||||
|
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should call Parameter.__torch_function__ and return its result
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
|
assert result is fallback_result
|
||||||
|
mock_cls.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFSDPPatchIntegration:
|
||||||
|
"""Test FSDP patch integration."""
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_all_patches_together(self):
|
||||||
|
"""Test that all patches can be applied together."""
|
||||||
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
|
apply_init_sharded_param_patch,
|
||||||
|
apply_init_unsharded_param_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original methods before patching
|
||||||
|
original_torch_function = getattr(
|
||||||
|
bnb.nn.modules.Params4bit, "__torch_function__", None
|
||||||
|
)
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
original_init_sharded = FSDPParam._init_sharded_param
|
||||||
|
original_init_unsharded = FSDPParam.init_unsharded_param
|
||||||
|
|
||||||
|
# Apply patches
|
||||||
|
apply_bnb_torch_function_patch()
|
||||||
|
apply_init_sharded_param_patch()
|
||||||
|
apply_init_unsharded_param_patch()
|
||||||
|
|
||||||
|
# Verify patches were applied
|
||||||
|
current_torch_function = getattr(
|
||||||
|
bnb.nn.modules.Params4bit, "__torch_function__", None
|
||||||
|
)
|
||||||
|
if original_torch_function is not None:
|
||||||
|
assert (
|
||||||
|
current_torch_function != original_torch_function
|
||||||
|
), "Params4bit.__torch_function__ was not patched"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
current_torch_function is not None
|
||||||
|
), "Params4bit.__torch_function__ was not added"
|
||||||
|
|
||||||
|
# Check that FSDP methods were patched
|
||||||
|
assert (
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
FSDPParam._init_sharded_param
|
||||||
|
!= original_init_sharded
|
||||||
|
), "_init_sharded_param was not patched"
|
||||||
|
assert (
|
||||||
|
FSDPParam.init_unsharded_param != original_init_unsharded
|
||||||
|
), "init_unsharded_param was not patched"
|
||||||
Reference in New Issue
Block a user