first pass at pytorch 2.5.0 support (#1982)
* first pass at pytorch 2.5.0 support * attempt to install causal_conv1d with mamba * gracefully handle missing xformers * fix import * fix incorrect version, add 2.5.0 * increase tests timeout
This commit is contained in:
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -29,6 +29,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -86,6 +91,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
13
.github/workflows/multi-gpu-e2e.yml
vendored
13
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,10 +21,17 @@ jobs:
|
|||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
- cuda: 121
|
- cuda: 124
|
||||||
cuda_version: 12.1.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.4.1
|
||||||
|
axolotl_extras:
|
||||||
|
num_gpus: 2
|
||||||
|
nightly_build: "true"
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -28,6 +28,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -85,6 +90,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
9
.github/workflows/tests-nightly.yml
vendored
9
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1"]
|
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -95,6 +95,13 @@ jobs:
|
|||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1"]
|
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -72,7 +72,7 @@ jobs:
|
|||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 60
|
timeout-minutes: 90
|
||||||
needs: [pre-commit, pytest]
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -97,6 +97,12 @@ jobs:
|
|||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
git checkout FETCH_HEAD
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
@stub.function(
|
@stub.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=45 * 60,
|
timeout=60 * 60,
|
||||||
cpu=8.0,
|
cpu=8.0,
|
||||||
memory=131072 * N_GPUS,
|
memory=131072 * N_GPUS,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
@stub.function(
|
@stub.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=45 * 60,
|
timeout=60 * 60,
|
||||||
cpu=8.0,
|
cpu=8.0,
|
||||||
memory=131072,
|
memory=131072,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
5
setup.py
5
setup.py
@@ -50,7 +50,9 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 4):
|
if (major, minor) >= (2, 5):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
elif (major, minor) >= (2, 4):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
@@ -102,6 +104,7 @@ setup(
|
|||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
"mamba-ssm==1.2.0.post1",
|
"mamba-ssm==1.2.0.post1",
|
||||||
|
"causal_conv1d",
|
||||||
],
|
],
|
||||||
"auto-gptq": [
|
"auto-gptq": [
|
||||||
"auto-gptq==0.5.1",
|
"auto-gptq==0.5.1",
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
from xformers.ops import SwiGLU
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
|
||||||
@@ -44,7 +43,19 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def is_xformers_available() -> bool:
|
||||||
|
try:
|
||||||
|
import xformers # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_xformers_swiglu_available() -> bool:
|
def is_xformers_swiglu_available() -> bool:
|
||||||
|
if not is_xformers_available():
|
||||||
|
return False
|
||||||
|
|
||||||
from xformers.ops.common import get_xformers_operator
|
from xformers.ops.common import get_xformers_operator
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def replace_llama_mlp_with_swiglu(model):
|
def replace_llama_mlp_with_swiglu(model):
|
||||||
|
if is_xformers_swiglu_available():
|
||||||
|
from axolotl.monkeypatch.xformers_ import FusedMLP
|
||||||
|
else:
|
||||||
|
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, LlamaMLP):
|
if isinstance(module, LlamaMLP):
|
||||||
mlp = FusedMLP(
|
mlp = FusedMLP(
|
||||||
@@ -181,49 +197,6 @@ class FusedAttention(LlamaAttention):
|
|||||||
set_module_name(model, name, new_attn)
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
class FusedMLP(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
gate_proj: torch.nn.Linear,
|
|
||||||
up_proj: torch.nn.Linear,
|
|
||||||
down_proj: torch.nn.Linear,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.swiglu = SwiGLU(
|
|
||||||
in_features=config.hidden_size,
|
|
||||||
hidden_features=config.intermediate_size,
|
|
||||||
bias=False,
|
|
||||||
_pack_weights=True,
|
|
||||||
)
|
|
||||||
# overwrite initialized weights with pretrained weights
|
|
||||||
self.swiglu.w12.weight.data = torch.cat(
|
|
||||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
|
||||||
)
|
|
||||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
|
||||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
|
||||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign the split weights back to the original layers
|
|
||||||
new_mlp = LlamaMLP(self.config)
|
|
||||||
new_mlp.gate_proj.weight.data = w1
|
|
||||||
new_mlp.up_proj.weight.data = w2
|
|
||||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
|
||||||
|
|
||||||
set_module_name(model, name, new_mlp)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
|
||||||
return self.swiglu(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
|
|||||||
51
src/axolotl/monkeypatch/xformers_/__init__.py
Normal file
51
src/axolotl/monkeypatch/xformers_/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import set_module_name
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMLP(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
gate_proj: torch.nn.Linear,
|
||||||
|
up_proj: torch.nn.Linear,
|
||||||
|
down_proj: torch.nn.Linear,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=config.hidden_size,
|
||||||
|
hidden_features=config.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.swiglu.w12.weight.data = torch.cat(
|
||||||
|
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||||
|
)
|
||||||
|
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||||
|
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the split weights back to the original layers
|
||||||
|
new_mlp = LlamaMLP(self.config)
|
||||||
|
new_mlp.gate_proj.weight.data = w1
|
||||||
|
new_mlp.up_proj.weight.data = w2
|
||||||
|
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_mlp)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
|
return self.swiglu(x)
|
||||||
Reference in New Issue
Block a user