upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)

* upgrade transformers==4.55.1

* also upgrade bnb

* remove bnb params4bit patch (upstreamed)

* use latest causal-conv1d

* fix patching ring-flash-attn with now missing imports

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
Wing Lian
2025-08-13 19:41:07 -04:00
committed by GitHub
parent e0a2523a3b
commit 09145de8fa
8 changed files with 38 additions and 176 deletions

View File

@@ -285,12 +285,10 @@ class PatchManager:
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()

View File

@@ -9,73 +9,12 @@ Params4bit parameters.
import importlib
import inspect
import torch
from torch.nn import Parameter
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patched_torch_function(cls, func, types, args=(), kwargs=None):
    """
    Patched version of Params4bit.__torch_function__ for preserving Params4bit
    class identity and attributes.

    Results of ``torch.chunk`` and ``torch.split`` are re-wrapped in ``cls``,
    copying the quantization attributes of the source tensor onto every piece.
    Any other function is delegated to ``Parameter.__torch_function__`` and its
    result is returned unchanged.

    Args:
        cls: Class used to re-wrap the resulting pieces; it is called with the
            keyword arguments ``data``, ``requires_grad``, ``quant_state``,
            ``blocksize``, ``compress_statistics``, ``quant_type``,
            ``quant_storage``, ``module`` and ``bnb_quantized``.
        func: The torch function being dispatched.
        types: Forwarded as-is to ``Parameter.__torch_function__``.
        args: Positional arguments of the dispatched call.
        kwargs: Keyword arguments of the dispatched call (``None`` means none).

    Returns:
        For ``torch.chunk`` / ``torch.split``: a tuple of ``cls`` instances when
        the underlying result is a tuple, otherwise a single ``cls`` instance.
        For every other function: whatever ``Parameter.__torch_function__``
        returns.
    """
    if kwargs is None:
        kwargs = {}

    result = Parameter.__torch_function__(func, types, args, kwargs)
    if func not in (torch.chunk, torch.split):
        return result

    # The source tensor is normally the first positional argument, but it may
    # also be passed by keyword (``input`` for torch.chunk, ``tensor`` for
    # torch.split). Indexing args[0] unconditionally would raise IndexError in
    # that case even though the underlying call succeeded.
    if args:
        tensor = args[0]
    elif "input" in kwargs:
        tensor = kwargs["input"]
    else:
        tensor = kwargs["tensor"]

    def _rewrap(piece):
        # Rebuild a ``cls`` instance around ``piece`` with the source's metadata.
        return cls(
            data=piece,
            requires_grad=tensor.requires_grad,
            quant_state=tensor.quant_state,
            blocksize=tensor.blocksize,
            compress_statistics=tensor.compress_statistics,
            quant_type=tensor.quant_type,
            quant_storage=tensor.quant_storage,
            module=tensor.module,
            bnb_quantized=tensor.bnb_quantized,
        )

    if isinstance(result, tuple):
        return tuple(_rewrap(piece) for piece in result)
    return _rewrap(result)
# pylint: disable=protected-access
def apply_bnb_torch_function_patch():
"""
Replace Params4bit.__torch_function__ with patched_torch_function.
The replacement is installed as a classmethod on bitsandbytes' Params4bit,
which is imported inside this function rather than at module import time,
and an info message is logged afterwards.
Returns:
None. There is no success flag; a failing bitsandbytes import propagates.
"""
from bitsandbytes.nn.modules import Params4bit
Params4bit.__torch_function__ = classmethod(patched_torch_function)
LOG.info("Successfully patched Params4bit.__torch_function__")
# pylint: disable=protected-access
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""

View File

@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try:
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

View File

@@ -15,12 +15,15 @@ import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
try:
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-boolean-expressions
import json
import sys
import tempfile
from pathlib import Path
@@ -1251,10 +1252,26 @@ class ComplexValidationMixin:
try:
import transformers.modeling_flash_attention_utils
from transformers.utils import is_flash_attn_greater_or_equal
# pylint: disable=protected-access
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
transformers.modeling_flash_attention_utils._flash_supports_window
transformers.modeling_flash_attention_utils._flash_supports_window = (
True
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_flash_supports_window",
True,
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_flash_supports_window_size",
True,
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"is_flash_attn_greater_or_equal",
is_flash_attn_greater_or_equal,
)
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception: