From e12a2130e990313bb0bce66be8fbbe5b856094dd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 21 Oct 2024 11:00:45 -0400
Subject: [PATCH] first pass at pytorch 2.5.0 support (#1982)

* first pass at pytorch 2.5.0 support

* attempt to install causal_conv1d with mamba

* gracefully handle missing xformers

* fix import

* fix incorrect version, add 2.5.0

* increase tests timeout
---
 .github/workflows/main.yml                     | 10 +++
 .github/workflows/multi-gpu-e2e.yml            | 13 +++-
 .github/workflows/nightlies.yml                | 10 +++
 .github/workflows/tests-nightly.yml            |  9 ++-
 .github/workflows/tests.yml                    | 10 ++-
 cicd/Dockerfile.jinja                          |  1 -
 cicd/multigpu.py                               |  2 +-
 cicd/tests.py                                  |  2 +-
 docker/Dockerfile                              |  1 -
 setup.py                                       |  5 +-
 .../monkeypatch/llama_attn_hijack_flash.py     | 61 ++++++-------------
 src/axolotl/monkeypatch/xformers_/__init__.py  | 51 ++++++++++++++++
 12 files changed, 120 insertions(+), 55 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/xformers_/__init__.py

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index c27dbedef..47a4c7f11 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -29,6 +29,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.4.1
             axolotl_extras:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -86,6 +91,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.4.1
             axolotl_extras:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index ab886c67f..d9f0ce7e6 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -21,10 +21,17 @@ jobs:
             pytorch: 2.3.1
             axolotl_extras:
             num_gpus: 2
-          - cuda: 121
-            cuda_version: 12.1.1
+          - cuda: 124
+            cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.3.1
+            pytorch: 2.4.1
+            axolotl_extras:
+            num_gpus: 2
+            nightly_build: "true"
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
             axolotl_extras:
             num_gpus: 2
             nightly_build: "true"
diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml
index 17c76c24e..55123a902 100644
--- a/.github/workflows/nightlies.yml
+++ b/.github/workflows/nightlies.yml
@@ -28,6 +28,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.4.1
             axolotl_extras:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -85,6 +90,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.4.1
             axolotl_extras:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index 8c9e1f49e..56eaae239 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
     timeout-minutes: 20

     steps:
@@ -95,6 +95,13 @@ jobs:
             num_gpus: 1
             axolotl_extras:
             nightly_build: "true"
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            num_gpus: 1
+            axolotl_extras:
+            nightly_build: "true"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
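Every workflow change above follows the same pattern: each build and test matrix gains an entry pairing pytorch 2.5.0 with CUDA 12.4.1 (published under the `cuda: 124` wheel tag), alongside the existing 2.3.1/2.4.1 entries. As an illustrative sanity check that is not part of the patch, the pairing can be confirmed from inside a built image:

```python
# Illustrative check, not part of this patch: confirm the torch/CUDA pairing
# baked into an image built from one of the new cuda 124 / pytorch 2.5.0 entries.
import torch

print(torch.__version__)   # e.g. "2.5.0+cu124"
print(torch.version.cuda)  # e.g. "12.4"
assert torch.version.cuda is not None and torch.version.cuda.startswith("12.4")
```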
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a798bdd5c..e679f4101 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -36,7 +36,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
     timeout-minutes: 20

     steps:
@@ -72,7 +72,7 @@ jobs:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
-    timeout-minutes: 60
+    timeout-minutes: 90
     needs: [pre-commit, pytest]

     strategy:
@@ -97,6 +97,12 @@ jobs:
             pytorch: 2.4.1
             num_gpus: 1
             axolotl_extras:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.0
+            num_gpus: 1
+            axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 11ce8d8ba..3b082a15b 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -23,7 +23,6 @@ RUN git fetch origin +$GITHUB_REF && \
     git checkout FETCH_HEAD

 # If AXOLOTL_EXTRAS is set, append it in brackets
-RUN pip install causal_conv1d
 RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
         sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
diff --git a/cicd/multigpu.py b/cicd/multigpu.py
index be10fbc73..da726b473 100644
--- a/cicd/multigpu.py
+++ b/cicd/multigpu.py
@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
 @stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=45 * 60,
+    timeout=60 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
 )
diff --git a/cicd/tests.py b/cicd/tests.py
index 9c2d830cb..9ebce9815 100644
--- a/cicd/tests.py
+++ b/cicd/tests.py
@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
 @stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=45 * 60,
+    timeout=60 * 60,
     cpu=8.0,
     memory=131072,
 )
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 2b106f1ed..4872b3907 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
 WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets
-RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
diff --git a/setup.py b/setup.py
index 7d9568dbf..1153d6968 100644
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,9 @@ def parse_requirements():
         else:
             raise ValueError("Invalid version format")

-        if (major, minor) >= (2, 4):
+        if (major, minor) >= (2, 5):
+            _install_requires.pop(_install_requires.index(xformers_version))
+        elif (major, minor) >= (2, 4):
             if patch == 0:
                 _install_requires.pop(_install_requires.index(xformers_version))
                 _install_requires.append("xformers>=0.0.27")
@@ -102,6 +104,7 @@ setup(
         ],
         "mamba-ssm": [
             "mamba-ssm==1.2.0.post1",
+            "causal_conv1d",
         ],
         "auto-gptq": [
             "auto-gptq==0.5.1",
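The setup.py hunks are the functional core of the bump. For torch >= 2.5 the xformers pin is dropped from the install requirements altogether (presumably because no xformers wheel matching 2.5.0 was available yet), while the existing 2.4 branch keeps its behavior: 2.4.0 swaps the exact pin for `xformers>=0.0.27`, and later 2.4.x patch releases keep the pin. The same file moves `causal_conv1d` into the `mamba-ssm` extra, which is why the unconditional `pip install causal_conv1d` lines could be deleted from both Dockerfiles. A standalone sketch of the gating follows; the function name and the `pinned` default are illustrative stand-ins, not the real `parse_requirements()` internals:

```python
# Illustrative sketch of the xformers gating above. `pinned` stands in for
# whatever exact pin requirements.txt carries; its value here is an assumption.
from typing import Optional


def resolve_xformers(torch_version: str, pinned: str = "xformers==0.0.26") -> Optional[str]:
    major, minor, patch = (int(part) for part in torch_version.split(".")[:3])
    if (major, minor) >= (2, 5):
        return None  # torch 2.5+: drop the pin entirely
    if (major, minor) >= (2, 4):
        # 2.4.0 trades the exact pin for a floor; later 2.4.x keep the pin
        return "xformers>=0.0.27" if patch == 0 else pinned
    return pinned


assert resolve_xformers("2.5.0") is None
assert resolve_xformers("2.4.0") == "xformers>=0.0.27"
assert resolve_xformers("2.3.1") == "xformers==0.0.26"
```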
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4c3571ea4..c804d0c6b 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -22,7 +22,6 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from xformers.ops import SwiGLU

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name

@@ -44,7 +43,19 @@ except ImportError:
 LOG = logging.getLogger("axolotl")


+def is_xformers_available() -> bool:
+    try:
+        import xformers  # pylint: disable=unused-import  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 def is_xformers_swiglu_available() -> bool:
+    if not is_xformers_available():
+        return False
+
     from xformers.ops.common import get_xformers_operator

     try:
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:


 def replace_llama_mlp_with_swiglu(model):
+    if is_xformers_swiglu_available():
+        from axolotl.monkeypatch.xformers_ import FusedMLP
+    else:
+        raise RuntimeError("xformers SwiGLU not available for this environment")
+
     for name, module in model.named_modules():
         if isinstance(module, LlamaMLP):
             mlp = FusedMLP(
@@ -181,49 +197,6 @@ class FusedAttention(LlamaAttention):
         set_module_name(model, name, new_attn)


-class FusedMLP(torch.nn.Module):
-    """
-    Fused MLP layer for incrementally improved training efficiency
-    """
-
-    def __init__(
-        self,
-        config,
-        gate_proj: torch.nn.Linear,
-        up_proj: torch.nn.Linear,
-        down_proj: torch.nn.Linear,
-    ):
-        super().__init__()
-        self.config = config
-        self.swiglu = SwiGLU(
-            in_features=config.hidden_size,
-            hidden_features=config.intermediate_size,
-            bias=False,
-            _pack_weights=True,
-        )
-        # overwrite initialized weights with pretrained weights
-        self.swiglu.w12.weight.data = torch.cat(
-            (gate_proj.weight.data, up_proj.weight.data), dim=0
-        )
-        self.swiglu.w3.weight.data = down_proj.weight.data
-
-    def _post_training(self, model, name):
-        w1, w2 = torch.split(  # pylint: disable=invalid-name
-            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
-        )
-
-        # Assign the split weights back to the original layers
-        new_mlp = LlamaMLP(self.config)
-        new_mlp.gate_proj.weight.data = w1
-        new_mlp.up_proj.weight.data = w2
-        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
-
-        set_module_name(model, name, new_mlp)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
-        return self.swiglu(x)
-
-
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
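With the module-level `from xformers.ops import SwiGLU` removed, importing `llama_attn_hijack_flash` no longer hard-requires xformers: the new `is_xformers_available()` probe defers the import, and `replace_llama_mlp_with_swiglu` now raises a clear `RuntimeError` only when the SwiGLU patch is actually requested in an environment without working support. A hypothetical caller-side sketch of that contract (`model` and `maybe_fuse_mlp` are illustrative, not from the patch):

```python
# Hypothetical usage of the guarded API above: probe first, patch second.
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    is_xformers_swiglu_available,
    replace_llama_mlp_with_swiglu,
)


def maybe_fuse_mlp(model):
    """Swap each LlamaMLP for the fused SwiGLU version when xformers allows it."""
    if is_xformers_swiglu_available():
        replace_llama_mlp_with_swiglu(model)
    # else: keep the stock LlamaMLP modules; xformers is absent or lacks SwiGLU
    return model
```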
diff --git a/src/axolotl/monkeypatch/xformers_/__init__.py b/src/axolotl/monkeypatch/xformers_/__init__.py
new file mode 100644
index 000000000..bddc036b2
--- /dev/null
+++ b/src/axolotl/monkeypatch/xformers_/__init__.py
@@ -0,0 +1,51 @@
+"""
+Fused MLP layer for incrementally improved training efficiency
+"""
+import torch
+from transformers.models.llama.modeling_llama import LlamaMLP
+from xformers.ops import SwiGLU
+
+from axolotl.monkeypatch.utils import set_module_name
+
+
+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = LlamaMLP(self.config)
+        new_mlp.gate_proj.weight.data = w1
+        new_mlp.up_proj.weight.data = w2
+        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
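The relocated `FusedMLP` (moved into its own module so the flash-attention hijack no longer imports xformers unconditionally) relies on one invariant: SwiGLU's packed `w12` weight is simply `gate_proj` and `up_proj` concatenated along dim 0, so `_post_training` can split at `intermediate_size` to recover the original Llama projections exactly. A toy round-trip check of that invariant, with illustrative sizes:

```python
# Toy round-trip check (illustrative sizes) of the pack/unpack invariant
# that FusedMLP.__init__ and FusedMLP._post_training depend on.
import torch

hidden_size, intermediate_size = 8, 16  # Llama-7B would be 4096 / 11008
gate = torch.randn(intermediate_size, hidden_size)
up = torch.randn(intermediate_size, hidden_size)

w12 = torch.cat((gate, up), dim=0)                   # pack, as in __init__
w1, w2 = torch.split(w12, intermediate_size, dim=0)  # unpack, as in _post_training
assert torch.equal(w1, gate) and torch.equal(w2, up)
```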