Compare commits

...

2 Commits

4 changed files with 63 additions and 34 deletions

View File

@@ -29,7 +29,8 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace
 
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
+    python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@v2.3.3#subdirectory=csrc/xentropy'
 
 RUN git lfs install --skip-repo && \
     pip3 install awscli && \
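
A quick way to sanity-check the resulting image is to import the optional kernels directly; a sketch (the import paths are the ones the monkeypatch below relies on, the commands themselves are illustrative):

    # fails with ImportError if the xentropy extension did not build
    python3 -c "from flash_attn.losses.cross_entropy import CrossEntropyLoss; print('xentropy ok')"
    # fails with ImportError if the dropout-layer-norm extension did not build
    python3 -c "from flash_attn.ops.rms_norm import RMSNorm; print('rms_norm ok')"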

View File

@@ -55,6 +55,12 @@ setup(
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"dropout-layer-norm": [
"dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
],
"xentropy-cuda-lib": [
"xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
],
"deepspeed": [
"deepspeed>=0.13.1",
"deepspeed-kernels",

View File

@@ -90,8 +90,14 @@ def replace_llama_attn_with_flash_attn(
             llama_model_forward
         )
 
     # skip only if explicitly disabled
     if cross_entropy:
+        patch_cross_entropy()
+    if rms_norm:
+        patch_rms_norm()
+
+
+def patch_cross_entropy():
     try:
         from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -104,9 +110,12 @@ def replace_llama_attn_with_flash_attn(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
# skip only if explicitly disabled
if rms_norm:
def patch_rms_norm():
try:
try:
from flash_attn.ops.triton.rms_norm import RMSNorm
except ImportError:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
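
The hunk is truncated at the class definition; below is a minimal sketch of how the extracted helper plausibly continues (the constructor signature and the final module-level swap are assumptions, not copied from the PR):

    import transformers


    def patch_rms_norm():
        try:
            try:
                # prefer the Triton kernel when it is available
                from flash_attn.ops.triton.rms_norm import RMSNorm
            except ImportError:
                # fall back to the CUDA extension built from csrc/layer_norm
                from flash_attn.ops.rms_norm import RMSNorm

            class LlamaRMSNorm(RMSNorm):
                """Match transformers' LlamaRMSNorm(hidden_size, eps) signature."""

                def __init__(self, hidden_size, eps=1e-6):
                    super().__init__(hidden_size, eps=eps)

            # swap the fused norm into the llama module so newly built models use it
            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
        except ImportError:
            # neither kernel is installed; keep the stock implementation
            pass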

View File

@@ -24,6 +24,12 @@ from transformers import ( # noqa: F401
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_cross_entropy as llama_patch_cross_entropy,
+)
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_rms_norm as llama_patch_rms_norm,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -281,15 +287,7 @@ def load_model(
             replace_llama_attn_with_flash_attn,
         )
 
-        if cfg.sample_packing:
-            if cfg.device not in ["mps", "cpu"] and not inference:
-                LOG.info("patching with flash attention for sample packing")
-                replace_llama_attn_with_flash_attn(
-                    packed=True,
-                    cross_entropy=cfg.flash_attn_cross_entropy,
-                    rms_norm=cfg.flash_attn_rms_norm,
-                )
-        elif cfg.s2_attention:
+        if cfg.s2_attention:
             LOG.info("patching w/ flash-enabled, shifted-sparse attention")
             replace_llama_attn_with_flash_attn(
                 packed=False,
@@ -297,6 +295,21 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
+        elif cfg.device not in ["mps", "cpu"] and not inference:
+            if cfg.sample_packing:
+                LOG.info("patching with flash attention for sample packing")
+                replace_llama_attn_with_flash_attn(
+                    packed=True,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                )
+            else:
+                if cfg.flash_attn_cross_entropy:
+                    llama_patch_cross_entropy()
+                if cfg.flash_attn_rms_norm:
+                    llama_patch_rms_norm()
     elif cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
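
Net effect of the reordering: cfg.s2_attention is checked first, the sample-packing patch moves under the shared device/inference guard, and, when neither applies, the fused cross-entropy and RMSNorm kernels can now be enabled on their own. A hypothetical standalone use that this refactor makes possible (imports mirror the hunk above):

    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        patch_cross_entropy as llama_patch_cross_entropy,
    )
    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        patch_rms_norm as llama_patch_rms_norm,
    )

    # apply only the fused-kernel patches, without touching attention itself
    llama_patch_cross_entropy()
    llama_patch_rms_norm()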