upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)
* upgrade transformers==4.55.1 * also upgrade bnb * remove bnb params4bit patch (upstreamed) * use latest causal-conv1d * fix patching ring-flash-attn with now missing imports --------- Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
@@ -37,7 +37,7 @@ WORKDIR /workspace
|
|||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||||
python3 -m pip cache purge
|
python3 -m pip cache purge
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.46.1
|
bitsandbytes==0.47.0
|
||||||
# triton 3.4.0 is not compatible with CCE
|
# triton 3.4.0 is not compatible with CCE
|
||||||
triton>=3.0.0,<3.4.0
|
triton>=3.0.0,<3.4.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
@@ -14,7 +14,7 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft==0.17.0
|
peft==0.17.0
|
||||||
transformers==4.55.0
|
transformers==4.55.1
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.10.0
|
accelerate==1.10.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
|
|||||||
@@ -285,12 +285,10 @@ class PatchManager:
|
|||||||
and self.cfg.adapter == "qlora"
|
and self.cfg.adapter == "qlora"
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
apply_bnb_torch_function_patch,
|
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
apply_init_unsharded_param_patch,
|
apply_init_unsharded_param_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
apply_init_sharded_param_patch()
|
apply_init_sharded_param_patch()
|
||||||
apply_init_unsharded_param_patch()
|
apply_init_unsharded_param_patch()
|
||||||
|
|
||||||
|
|||||||
@@ -9,73 +9,12 @@ Params4bit parameters.
|
|||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Parameter
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def patched_torch_function(cls, func, types, args=(), kwargs=None):
|
|
||||||
"""
|
|
||||||
Patched version of Params4bit.__torch_function__ for preserving Params4bit
|
|
||||||
class identity and attributes.
|
|
||||||
"""
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
if func in [torch.chunk, torch.split]:
|
|
||||||
tensor = args[0]
|
|
||||||
result = Parameter.__torch_function__(func, types, args, kwargs)
|
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
return tuple(
|
|
||||||
cls(
|
|
||||||
data=chunk,
|
|
||||||
requires_grad=tensor.requires_grad,
|
|
||||||
quant_state=tensor.quant_state,
|
|
||||||
blocksize=tensor.blocksize,
|
|
||||||
compress_statistics=tensor.compress_statistics,
|
|
||||||
quant_type=tensor.quant_type,
|
|
||||||
quant_storage=tensor.quant_storage,
|
|
||||||
module=tensor.module,
|
|
||||||
bnb_quantized=tensor.bnb_quantized,
|
|
||||||
)
|
|
||||||
for chunk in result
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
data=result,
|
|
||||||
requires_grad=tensor.requires_grad,
|
|
||||||
quant_state=tensor.quant_state,
|
|
||||||
blocksize=tensor.blocksize,
|
|
||||||
compress_statistics=tensor.compress_statistics,
|
|
||||||
quant_type=tensor.quant_type,
|
|
||||||
quant_storage=tensor.quant_storage,
|
|
||||||
module=tensor.module,
|
|
||||||
bnb_quantized=tensor.bnb_quantized,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Parameter.__torch_function__(func, types, args, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
def apply_bnb_torch_function_patch():
|
|
||||||
"""
|
|
||||||
Patch Params4bit.__torch_function__ using Axolotl-style approach.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if patching succeeded, False otherwise.
|
|
||||||
"""
|
|
||||||
from bitsandbytes.nn.modules import Params4bit
|
|
||||||
|
|
||||||
Params4bit.__torch_function__ = classmethod(patched_torch_function)
|
|
||||||
|
|
||||||
LOG.info("Successfully patched Params4bit.__torch_function__")
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def apply_init_sharded_param_patch():
|
def apply_init_sharded_param_patch():
|
||||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||||
|
|||||||
@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
|
|||||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||||
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
||||||
|
|
||||||
try:
|
try: # pylint: disable=duplicate-code
|
||||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from transformers.modeling_flash_attention_utils import (
|
try:
|
||||||
_flash_supports_window_size as _flash_supports_window,
|
from transformers.modeling_flash_attention_utils import (
|
||||||
)
|
_flash_supports_window_size as _flash_supports_window,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
_flash_supports_window = True
|
||||||
|
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,15 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import DeviceMesh
|
from torch.distributed import DeviceMesh
|
||||||
|
|
||||||
try:
|
try: # pylint: disable=duplicate-code
|
||||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from transformers.modeling_flash_attention_utils import (
|
try:
|
||||||
_flash_supports_window_size as _flash_supports_window,
|
from transformers.modeling_flash_attention_utils import (
|
||||||
)
|
_flash_supports_window_size as _flash_supports_window,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
_flash_supports_window = True
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# pylint: disable=too-many-boolean-expressions
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -1251,10 +1252,26 @@ class ComplexValidationMixin:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import transformers.modeling_flash_attention_utils
|
import transformers.modeling_flash_attention_utils
|
||||||
|
from transformers.utils import is_flash_attn_greater_or_equal
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
|
transformers.modeling_flash_attention_utils._flash_supports_window = (
|
||||||
transformers.modeling_flash_attention_utils._flash_supports_window
|
True
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||||
|
"_flash_supports_window",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||||
|
"_flash_supports_window_size",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||||
|
"is_flash_attn_greater_or_equal",
|
||||||
|
is_flash_attn_greater_or_equal,
|
||||||
)
|
)
|
||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
|
|||||||
@@ -1,126 +1,28 @@
|
|||||||
"""Integration tests for FSDP Params4bit patches."""
|
"""Integration tests for FSDP2 Params4bit patches."""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
|
||||||
apply_bnb_torch_function_patch,
|
|
||||||
patched_torch_function,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_params4bit():
|
|
||||||
"""Create a mock Params4bit instance with test attributes."""
|
|
||||||
mock_instance = Mock()
|
|
||||||
mock_instance.requires_grad = True
|
|
||||||
mock_instance.quant_state = "test_state"
|
|
||||||
mock_instance.blocksize = 128
|
|
||||||
mock_instance.compress_statistics = True
|
|
||||||
mock_instance.quant_type = "fp4"
|
|
||||||
mock_instance.quant_storage = "test_storage"
|
|
||||||
mock_instance.module = "test_module"
|
|
||||||
mock_instance.bnb_quantized = True
|
|
||||||
return mock_instance
|
|
||||||
|
|
||||||
|
|
||||||
class TestBnbTorchFunctionPatch:
|
|
||||||
"""Test the Params4bit.__torch_function__ patch."""
|
|
||||||
|
|
||||||
def test_apply_patch(self):
|
|
||||||
"""Test that the patch can be applied."""
|
|
||||||
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
|
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
assert hasattr(mock_cls, "__torch_function__")
|
|
||||||
assert isinstance(mock_cls.__torch_function__, classmethod)
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
|
|
||||||
"""Test that torch.chunk preserves Params4bit attributes."""
|
|
||||||
mock_cls = Mock()
|
|
||||||
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
|
|
||||||
|
|
||||||
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
|
|
||||||
result = patched_torch_function(
|
|
||||||
mock_cls,
|
|
||||||
torch.chunk,
|
|
||||||
(type(mock_params4bit),),
|
|
||||||
args=(mock_params4bit, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, tuple)
|
|
||||||
assert len(result) == 2
|
|
||||||
|
|
||||||
# Check that Params4bit constructor was called with preserved attributes
|
|
||||||
assert mock_cls.call_count == 2
|
|
||||||
for call in mock_cls.call_args_list:
|
|
||||||
kwargs = call[1]
|
|
||||||
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
|
|
||||||
assert kwargs["quant_state"] == mock_params4bit.quant_state
|
|
||||||
assert kwargs["blocksize"] == mock_params4bit.blocksize
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
def test_other_functions_fallback(self, mock_params4bit):
|
|
||||||
"""Test that non-chunk/split functions use Parameter fallback."""
|
|
||||||
mock_cls = Mock()
|
|
||||||
fallback_result = torch.tensor([5, 6, 7])
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
|
|
||||||
) as mock_fallback:
|
|
||||||
result = patched_torch_function(
|
|
||||||
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should call Parameter.__torch_function__ and return its result
|
|
||||||
mock_fallback.assert_called_once()
|
|
||||||
assert result is fallback_result
|
|
||||||
mock_cls.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
class TestFSDPPatchIntegration:
|
class TestFSDPPatchIntegration:
|
||||||
"""Test FSDP patch integration."""
|
"""Test FSDP patch integration."""
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_all_patches_together(self):
|
def test_fsdp2_init_patches(self):
|
||||||
"""Test that all patches can be applied together."""
|
"""Test that all patches can be applied together."""
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
apply_init_unsharded_param_patch,
|
apply_init_unsharded_param_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store original methods before patching
|
|
||||||
original_torch_function = getattr(
|
|
||||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
|
||||||
)
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
original_init_sharded = FSDPParam._init_sharded_param
|
original_init_sharded = FSDPParam._init_sharded_param
|
||||||
original_init_unsharded = FSDPParam.init_unsharded_param
|
original_init_unsharded = FSDPParam.init_unsharded_param
|
||||||
|
|
||||||
# Apply patches
|
# Apply patches
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
apply_init_sharded_param_patch()
|
apply_init_sharded_param_patch()
|
||||||
apply_init_unsharded_param_patch()
|
apply_init_unsharded_param_patch()
|
||||||
|
|
||||||
# Verify patches were applied
|
|
||||||
current_torch_function = getattr(
|
|
||||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
|
||||||
)
|
|
||||||
if original_torch_function is not None:
|
|
||||||
assert (
|
|
||||||
current_torch_function != original_torch_function
|
|
||||||
), "Params4bit.__torch_function__ was not patched"
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
current_torch_function is not None
|
|
||||||
), "Params4bit.__torch_function__ was not added"
|
|
||||||
|
|
||||||
# Check that FSDP methods were patched
|
|
||||||
assert (
|
assert (
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
FSDPParam._init_sharded_param
|
FSDPParam._init_sharded_param
|
||||||
|
|||||||
Reference in New Issue
Block a user