From 98af5388ba8fd2adde847f8b868a5fb2dcf9367d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 14 Jul 2024 19:11:31 -0400
Subject: [PATCH] bump flash attention 2.5.8 -> 2.6.1 (#1738)

* bump flash attention 2.5.8 -> 2.6.1

* use triton implementation of cross entropy from flash attn

* add smoke test for flash attn cross entropy patch

* fix args to xentropy.apply

* handle tuple from triton loss fn

* ensure the patch tests run independently

* use the wrapper already built into flash attn for cross entropy

* mark pytest as forked for patches

* use pytest xdist instead of forked, since cuda doesn't like forking

* limit to 1 process and use dist loadfile for pytest

* change up pytest for fixture to reload transformers w monkeypatch
---
 cicd/Dockerfile.jinja                      |  2 +-
 requirements-tests.txt                     |  1 +
 requirements.txt                           |  2 +-
 setup.py                                   |  4 +-
 src/axolotl/integrations/__init__.py       |  0
 .../monkeypatch/llama_attn_hijack_flash.py | 15 ++--
 src/axolotl/utils/models.py                |  6 ++
 tests/e2e/patched/test_fa_xentropy.py      | 87 +++++++++++++++++++
 8 files changed, 103 insertions(+), 14 deletions(-)
 create mode 100644 src/axolotl/integrations/__init__.py
 create mode 100644 tests/e2e/patched/test_fa_xentropy.py

diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 96c312ddc..287d563c1 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -30,7 +30,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     fi
 
 # So we can test the Docker image
-RUN pip install pytest
+RUN pip install -r requirements-tests.txt
 
 # fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
diff --git a/requirements-tests.txt b/requirements-tests.txt
index e079f8a60..9cda381d0 100644
--- a/requirements-tests.txt
+++ b/requirements-tests.txt
@@ -1 +1,2 @@
 pytest
+pytest-xdist
diff --git a/requirements.txt b/requirements.txt
index e24845a44..abefa3be3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ fire
 PyYAML>=6.0
 requests
 datasets==2.19.1
-flash-attn==2.5.8
+flash-attn==2.6.1
 sentencepiece
 wandb
 einops
diff --git a/setup.py b/setup.py
index 58d279475..82e652241 100644
--- a/setup.py
+++ b/setup.py
@@ -80,10 +80,10 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.5.8",
+            "flash-attn==2.6.1",
         ],
         "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
         ],
         "deepspeed": [
             "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
diff --git a/src/axolotl/integrations/__init__.py b/src/axolotl/integrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 6d7a23f0d..9377cb03f 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -104,17 +104,12 @@ def replace_llama_attn_with_flash_attn(
 
     # skip only if explicitly disabled
     if cross_entropy:
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
-            LOG.info("patching with flash_attn.losses.cross_entropy")
-            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-                CrossEntropyLoss, inplace_backward=True
-            )
-        except ImportError:
-            LOG.warning(
-                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-            )
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
 
     # skip only if explicitly disabled
     if rms_norm:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d8eac1ce1..19745ef8b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -371,6 +371,12 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
+        elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
+            replace_llama_attn_with_flash_attn(
+                packed=False,
+                cross_entropy=cfg.flash_attn_cross_entropy,
+                rms_norm=cfg.flash_attn_rms_norm,
+            )
     elif cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
new file mode 100644
index 000000000..0991bdd74
--- /dev/null
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -0,0 +1,87 @@
+"""
+E2E tests for the flash-attn cross entropy patch with llama
+"""
+
+import logging
+import os
+import unittest
+from importlib import reload
+from pathlib import Path
+
+import pytest
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+@pytest.fixture(autouse=True)
+def reload_transformers():
+    import transformers.models.llama.modeling_llama
+
+    yield
+    reload(transformers.models.llama.modeling_llama)
+
+
+class TestFAXentropyLlama(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA w multipack and flash-attn cross entropy
+    """
+
+    @with_temp_dir
+    def test_lora_packing_fa_cross_entropy(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "flash_attn_cross_entropy": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.2,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()