Compare commits

...

2 Commits

4 changed files with 63 additions and 34 deletions

View File

@@ -29,7 +29,8 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace

 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
+    python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy'

 RUN git lfs install --skip-repo && \
     pip3 install awscli && \

View File

@@ -53,7 +53,13 @@ setup(
"flash-attn==2.3.3", "flash-attn==2.3.3",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"dropout-layer-norm": [
"dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
],
"xentropy-cuda-lib": [
"xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed>=0.13.1", "deepspeed>=0.13.1",

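These new extras mirror the optional flash-attention CUDA kernels that the image build above now installs. From a source checkout they can be pulled in with pip's extras syntax, for example `pip install -e '.[flash-attn,dropout-layer-norm,xentropy-cuda-lib]'` (a hedged example; which extras are actually needed depends on whether the fused cross-entropy and RMSNorm paths are wanted).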
View File

@@ -90,37 +90,46 @@ def replace_llama_attn_with_flash_attn(
         llama_model_forward
     )

     # skip only if explicitly disabled
     if cross_entropy:
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-            LOG.info("patching with flash_attn.losses.cross_entropy")
-            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-                CrossEntropyLoss, inplace_backward=True
-            )
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-            )
+        patch_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
-
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
-
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
-
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
+        patch_rms_norm()
+
+
+def patch_cross_entropy():
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+        )
+
+
+def patch_rms_norm():
+    try:
+        try:
+            from flash_attn.ops.triton.rms_norm import RMSNorm
+        except ImportError:
+            from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )


 class FusedAttention(LlamaAttention):

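Splitting the cross-entropy and RMSNorm patches out of replace_llama_attn_with_flash_attn means they can be applied on their own. A minimal sketch, assuming the import path shown in the load_model changes below (importing transformers beforehand and having the optional flash-attn kernels installed is up to the caller):

    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        patch_cross_entropy,
        patch_rms_norm,
    )

    # Each helper swaps the corresponding class in
    # transformers.models.llama.modeling_llama for the flash-attn fused
    # implementation, or logs an install hint on ImportError.
    patch_cross_entropy()
    patch_rms_norm()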
View File

@@ -24,6 +24,12 @@ from transformers import ( # noqa: F401
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_cross_entropy as llama_patch_cross_entropy,
+)
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_rms_norm as llama_patch_rms_norm,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -281,15 +287,7 @@ def load_model(
             replace_llama_attn_with_flash_attn,
         )

-        if cfg.sample_packing:
-            if cfg.device not in ["mps", "cpu"] and not inference:
-                LOG.info("patching with flash attention for sample packing")
-                replace_llama_attn_with_flash_attn(
-                    packed=True,
-                    cross_entropy=cfg.flash_attn_cross_entropy,
-                    rms_norm=cfg.flash_attn_rms_norm,
-                )
-        elif cfg.s2_attention:
+        if cfg.s2_attention:
             LOG.info("patching w/ flash-enabled, shifted-sparse attention")
             replace_llama_attn_with_flash_attn(
                 packed=False,
@@ -297,6 +295,21 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
+        elif cfg.device not in ["mps", "cpu"] and not inference:
+            if cfg.sample_packing:
+                LOG.info("patching with flash attention for sample packing")
+                replace_llama_attn_with_flash_attn(
+                    packed=True,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                )
+            else:
+                if cfg.flash_attn_cross_entropy:
+                    llama_patch_cross_entropy()
+                if cfg.flash_attn_rms_norm:
+                    llama_patch_rms_norm()
     elif cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
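Net effect of the reordered branches: the flash-attn cross-entropy and RMSNorm patches no longer require sample packing to be enabled. A rough illustration of which branch a config lands in, using a stand-in namespace rather than axolotl's real cfg object (attribute names are taken from the diff above):

    from types import SimpleNamespace

    # Hypothetical config: no shifted-sparse attention, no sample packing,
    # but both fused-kernel flags enabled.
    cfg = SimpleNamespace(
        s2_attention=False,
        sample_packing=False,
        device="cuda:0",
        flash_attn_cross_entropy=True,
        flash_attn_rms_norm=True,
    )
    inference = False

    chosen = "no llama flash-attn patching"
    if cfg.s2_attention:
        chosen = "shifted-sparse flash attention patch"
    elif cfg.device not in ["mps", "cpu"] and not inference:
        if cfg.sample_packing:
            chosen = "flash attention patch for sample packing"
        else:
            chosen = "standalone cross-entropy / RMSNorm patches"
    print(chosen)  # -> standalone cross-entropy / RMSNorm patches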