upgrade to support latest transformers release (#2984)

* upgrade to support latest transformers release (4.54.0)

* bump mistral-common too

* Fix dependencies
This commit is contained in:
Wing Lian
2025-07-27 17:05:12 -04:00
committed by GitHub
parent 430be216d8
commit 1d2aa1e467
6 changed files with 29 additions and 19 deletions

View File

@@ -37,14 +37,14 @@ jobs:
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.0 pytorch: 2.7.0
axolotl_extras: vllm axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.7.1
axolotl_extras: axolotl_extras: vllm
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]

View File

@@ -19,5 +19,7 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-append \ --cov-append \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov if CODECOV_TOKEN is available
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true if [ -n "$CODECOV_TOKEN" ]; then
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
fi

View File

@@ -13,13 +13,13 @@ packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft==0.16.0 peft==0.16.0
transformers==4.53.2 transformers==4.54.0
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.9.0 accelerate==1.9.0
datasets==4.0.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.19.1 trl==0.19.1
hf_xet==1.1.2 hf_xet==1.1.5
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3
mistral-common==1.7.0 mistral-common==1.8.3

View File

@@ -68,9 +68,10 @@ def parse_requirements(extras_require_map):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
_install_requires.append("xformers==0.0.30") _install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else: else:
_install_requires.append("xformers==0.0.31.post1") _install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3") _install_requires.append("xformers==0.0.29.post3")
@@ -84,7 +85,9 @@ def parse_requirements(extras_require_map):
else: else:
_install_requires.append("xformers>=0.0.28.post3") _install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version)) _install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4): elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27") _install_requires.append("xformers>=0.0.27")
@@ -114,10 +117,10 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"], "flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.8.0.post2", "flash-attn==2.8.2",
"ring-flash-attn>=0.1.5", "ring-flash-attn>=0.1.7",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],
"deepspeed": [ "deepspeed": [
@@ -151,13 +154,12 @@ extras_require = {
"ray[train]", "ray[train]",
], ],
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.10.0",
], ],
"llmcompressor": [ "llmcompressor": [
"llmcompressor==0.5.1", "llmcompressor==0.5.1",
], ],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require extras_require
) )

View File

@@ -500,6 +500,7 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs[arg] = getattr(self.cfg, arg) training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False
if self.cfg.eval_batch_size: if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = ( training_args_kwargs["per_device_eval_batch_size"] = (

View File

@@ -18,10 +18,15 @@ import transformers
import transformers.modeling_flash_attention_utils import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import ( from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
_flash_supports_window_size,
is_flash_attn_greater_or_equal, try:
) from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@@ -112,7 +117,7 @@ def create_flash_attn_forward_varlen_llama3(
# Handle sliding window # Handle sliding window
use_sliding_windows = ( use_sliding_windows = (
_flash_supports_window_size _flash_supports_window
and sliding_window is not None and sliding_window is not None
and key_states.shape[1] > sliding_window and key_states.shape[1] > sliding_window
) )