Compare commits
20 Commits
attn-imple
...
fa3-hopper
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9bdf4b1c23 | ||
|
|
d6f64a3684 | ||
|
|
0735454782 | ||
|
|
bb6464c4c6 | ||
|
|
323a9cb153 | ||
|
|
b22150751f | ||
|
|
8c4bc59bfc | ||
|
|
a064f1c9b4 | ||
|
|
fb5ef6d445 | ||
|
|
34b68ddaae | ||
|
|
9a3d0c919b | ||
|
|
bd34d0b861 | ||
|
|
37220ab90a | ||
|
|
e1b74d710b | ||
|
|
79daf5b934 | ||
|
|
ddd7c55576 | ||
|
|
65c6c98a76 | ||
|
|
4ef2e8293f | ||
|
|
c126d5cd04 | ||
|
|
9b0be4f15c |
11
.github/workflows/base.yml
vendored
11
.github/workflows/base.yml
vendored
@@ -47,11 +47,18 @@ jobs:
|
|||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "126"
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
suffix: "-hopper"
|
||||||
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -87,7 +94,7 @@ jobs:
|
|||||||
context: .
|
context: .
|
||||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||||
|
|||||||
11
.github/workflows/multi-gpu-e2e.yml
vendored
11
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -32,21 +32,25 @@ jobs:
|
|||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
|
suffix: "-hopper"
|
||||||
|
num_gpus: 2
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
steps:
|
steps:
|
||||||
@@ -68,7 +72,6 @@ jobs:
|
|||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
|
||||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
|
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
|
||||||
|
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
fi
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG CUDA_VERSION="11.8.0"
|
ARG CUDA_VERSION="12.4.1"
|
||||||
ARG CUDNN_VERSION="8"
|
ARG CUDNN_VERSION=""
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
|
|||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
ARG PYTHON_VERSION="3.10"
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.5.1"
|
||||||
ARG CUDA="118"
|
ARG CUDA="124"
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
|
||||||
&& wget \
|
&& wget \
|
||||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
|
||||||
|
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
||||||
pip3 install flash-attn==2.7.4.post1; \
|
pip3 install flash-attn==2.7.4.post1; \
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -629,6 +629,49 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
|
use_fa3 = False
|
||||||
|
if self.cfg.use_flash_attention_3 is True:
|
||||||
|
use_fa3 = True
|
||||||
|
elif self.cfg.use_flash_attention_3 == "auto":
|
||||||
|
if torch.cuda.get_device_capability() >= (9, 0):
|
||||||
|
# FA3 is only available on Hopper GPUs and newer
|
||||||
|
use_fa3 = True
|
||||||
|
if not importlib.util.find_spec("flash_attn_interface"):
|
||||||
|
use_fa3 = False
|
||||||
|
if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
|
||||||
|
# this can happen when use_flash_attention_3 is explicity set to True
|
||||||
|
# and flash_attn_interface is not installed
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"Please install the flash_attn_interface library to use Flash Attention 3.x"
|
||||||
|
)
|
||||||
|
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
|
||||||
|
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
|
||||||
|
from flash_attn_interface import (
|
||||||
|
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
||||||
|
)
|
||||||
|
|
||||||
|
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
||||||
|
kwargs.pop("dropout_p", None)
|
||||||
|
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||||
|
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||||
|
args = (*args[:3],) + args[4:]
|
||||||
|
return flash_attn_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
|
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
||||||
|
kwargs.pop("dropout_p", None)
|
||||||
|
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||||
|
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||||
|
args = (*args[:3],) + args[4:]
|
||||||
|
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||||
|
flash_attn_func_v3_wrapper
|
||||||
|
)
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||||
|
flash_attn_varlen_func_v3_wrapper
|
||||||
|
)
|
||||||
|
LOG.info("Switched to Flash Attention v3")
|
||||||
|
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|
||||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
@@ -699,6 +742,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_mllama()
|
patch_mllama()
|
||||||
|
|
||||||
|
# TODO deprecate soon
|
||||||
if self.model_config.model_type == "btlm":
|
if self.model_config.model_type == "btlm":
|
||||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||||
replace_btlm_attn_with_flash_attn,
|
replace_btlm_attn_with_flash_attn,
|
||||||
@@ -706,6 +750,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
|
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
|
||||||
|
|
||||||
|
# TODO deprecate soon
|
||||||
if (
|
if (
|
||||||
self.model_config.model_type == "stablelm_epoch"
|
self.model_config.model_type == "stablelm_epoch"
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
|
|||||||
@@ -233,6 +233,7 @@ class AxolotlInputConfig(
|
|||||||
flash_attn_fuse_qkv: bool | None = None
|
flash_attn_fuse_qkv: bool | None = None
|
||||||
flash_attn_fuse_mlp: bool | None = None
|
flash_attn_fuse_mlp: bool | None = None
|
||||||
flash_optimum: bool | None = None
|
flash_optimum: bool | None = None
|
||||||
|
use_flash_attention_3: Literal["auto"] | bool | None = None
|
||||||
|
|
||||||
eager_attention: bool | None = None
|
eager_attention: bool | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -421,6 +421,7 @@ def temp_dir():
|
|||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def cleanup_monkeypatches():
|
def cleanup_monkeypatches():
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
|
|||||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
original_trainer_training_step = Trainer.training_step
|
original_trainer_training_step = Trainer.training_step
|
||||||
|
original_fa_func = None
|
||||||
|
original_fa_varlen_func = None
|
||||||
|
if (
|
||||||
|
importlib.util.find_spec("flash_attn")
|
||||||
|
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
|
||||||
|
and hasattr(
|
||||||
|
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
|
||||||
|
)
|
||||||
|
):
|
||||||
|
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
|
||||||
|
original_fa_varlen_func = (
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
|
||||||
|
)
|
||||||
# monkey patches can happen inside the tests
|
# monkey patches can happen inside the tests
|
||||||
yield
|
yield
|
||||||
# Reset LlamaFlashAttention2 forward
|
# Reset LlamaFlashAttention2 forward
|
||||||
@@ -444,6 +458,11 @@ def cleanup_monkeypatches():
|
|||||||
original_trainer_inner_training_loop
|
original_trainer_inner_training_loop
|
||||||
)
|
)
|
||||||
Trainer.training_step = original_trainer_training_step
|
Trainer.training_step = original_trainer_training_step
|
||||||
|
if original_fa_func:
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||||
|
original_fa_varlen_func
|
||||||
|
)
|
||||||
|
|
||||||
# Reset other known monkeypatches
|
# Reset other known monkeypatches
|
||||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||||
@@ -458,6 +477,7 @@ def cleanup_monkeypatches():
|
|||||||
("transformers.trainer",),
|
("transformers.trainer",),
|
||||||
("transformers", ["Trainer"]),
|
("transformers", ["Trainer"]),
|
||||||
("transformers.loss.loss_utils",),
|
("transformers.loss.loss_utils",),
|
||||||
|
("transformers.modeling_flash_attention_utils",),
|
||||||
]
|
]
|
||||||
for module_name_tuple in modules_to_reset:
|
for module_name_tuple in modules_to_reset:
|
||||||
module_name = module_name_tuple[0]
|
module_name = module_name_tuple[0]
|
||||||
|
|||||||
@@ -101,7 +101,13 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps",
|
"gradient_accumulation_steps",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
)
|
)
|
||||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
@pytest.mark.parametrize(
|
||||||
|
"use_flash_attention_3",
|
||||||
|
[False, "auto"],
|
||||||
|
)
|
||||||
|
def test_lora_ddp_packed(
|
||||||
|
self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -138,6 +144,7 @@ class TestMultiGPULlama:
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
|
"use_flash_attention_3": use_flash_attention_3,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ E2E tests for packed training
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
@@ -14,18 +13,17 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_tensorboard, with_temp_dir
|
from .utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestPackedLlama(unittest.TestCase):
|
class TestPackedLlama:
|
||||||
"""
|
"""
|
||||||
Test case for Packed training of llama models
|
Test case for Packed training of llama models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_loss_packed(self, temp_dir):
|
def test_loss_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user