upgrade to support latest transformers release (#2984)
* upgrade to support latest transformers release * bump mistral common too * Fix dependencies
This commit is contained in:
@@ -500,6 +500,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
training_args_kwargs["average_tokens_across_devices"] = False
|
||||
|
||||
if self.cfg.eval_batch_size:
|
||||
training_args_kwargs["per_device_eval_batch_size"] = (
|
||||
|
||||
@@ -18,10 +18,15 @@ import transformers
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from ring_flash_attn import ring_flash_attn_func
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size,
|
||||
is_flash_attn_greater_or_equal,
|
||||
)
|
||||
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
except ImportError:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
@@ -112,7 +117,7 @@ def create_flash_attn_forward_varlen_llama3(
|
||||
|
||||
# Handle sliding window
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size
|
||||
_flash_supports_window
|
||||
and sliding_window is not None
|
||||
and key_states.shape[1] > sliding_window
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user