upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)
* upgrade transformers==4.55.1 * also upgrade bnb * remove bnb params4bit patch (upstreamed) * use latest causal-conv1d * fix patching ring-flash-attn with now missing imports --------- Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
@@ -285,12 +285,10 @@ class PatchManager:
|
||||
and self.cfg.adapter == "qlora"
|
||||
):
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_bnb_torch_function_patch,
|
||||
apply_init_sharded_param_patch,
|
||||
apply_init_unsharded_param_patch,
|
||||
)
|
||||
|
||||
apply_bnb_torch_function_patch()
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
|
||||
|
||||
@@ -9,73 +9,12 @@ Params4bit parameters.
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patched_torch_function(cls, func, types, args=(), kwargs=None):
|
||||
"""
|
||||
Patched version of Params4bit.__torch_function__ for preserving Params4bit
|
||||
class identity and attributes.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if func in [torch.chunk, torch.split]:
|
||||
tensor = args[0]
|
||||
result = Parameter.__torch_function__(func, types, args, kwargs)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
return tuple(
|
||||
cls(
|
||||
data=chunk,
|
||||
requires_grad=tensor.requires_grad,
|
||||
quant_state=tensor.quant_state,
|
||||
blocksize=tensor.blocksize,
|
||||
compress_statistics=tensor.compress_statistics,
|
||||
quant_type=tensor.quant_type,
|
||||
quant_storage=tensor.quant_storage,
|
||||
module=tensor.module,
|
||||
bnb_quantized=tensor.bnb_quantized,
|
||||
)
|
||||
for chunk in result
|
||||
)
|
||||
|
||||
return cls(
|
||||
data=result,
|
||||
requires_grad=tensor.requires_grad,
|
||||
quant_state=tensor.quant_state,
|
||||
blocksize=tensor.blocksize,
|
||||
compress_statistics=tensor.compress_statistics,
|
||||
quant_type=tensor.quant_type,
|
||||
quant_storage=tensor.quant_storage,
|
||||
module=tensor.module,
|
||||
bnb_quantized=tensor.bnb_quantized,
|
||||
)
|
||||
|
||||
return Parameter.__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def apply_bnb_torch_function_patch():
|
||||
"""
|
||||
Patch Params4bit.__torch_function__ using Axolotl-style approach.
|
||||
|
||||
Returns:
|
||||
True if patching succeeded, False otherwise.
|
||||
"""
|
||||
from bitsandbytes.nn.modules import Params4bit
|
||||
|
||||
Params4bit.__torch_function__ = classmethod(patched_torch_function)
|
||||
|
||||
LOG.info("Successfully patched Params4bit.__torch_function__")
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def apply_init_sharded_param_patch():
|
||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||
|
||||
@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
||||
|
||||
try:
|
||||
try: # pylint: disable=duplicate-code
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
except ImportError:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
except ImportError:
|
||||
_flash_supports_window = True
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
|
||||
@@ -15,12 +15,15 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import DeviceMesh
|
||||
|
||||
try:
|
||||
try: # pylint: disable=duplicate-code
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
except ImportError:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
except ImportError:
|
||||
_flash_supports_window = True
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1251,10 +1252,26 @@ class ComplexValidationMixin:
|
||||
|
||||
try:
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from transformers.utils import is_flash_attn_greater_or_equal
|
||||
|
||||
# pylint: disable=protected-access
|
||||
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
|
||||
transformers.modeling_flash_attention_utils._flash_supports_window
|
||||
transformers.modeling_flash_attention_utils._flash_supports_window = (
|
||||
True
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_flash_supports_window",
|
||||
True,
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_flash_supports_window_size",
|
||||
True,
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"is_flash_attn_greater_or_equal",
|
||||
is_flash_attn_greater_or_equal,
|
||||
)
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
|
||||
Reference in New Issue
Block a user