From bcc78d8fa393ae07f4df364d1104d63cf778c9e1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 3 Jan 2024 15:11:04 -0500
Subject: [PATCH] bump transformers and update attention class map name (#1023)

* bump transformers and update attention class map name
* also run the tests in docker
* add mixtral e2e smoke test
* fix base name for docker image in test
* mixtral lora doesn't seem to work, at least check qlora
* add testcase for mixtral w sample packing
* check monkeypatch for flash attn multipack
* also run the e2e tests in docker
* use all gpus to run tests in docker ci
* use privileged mode too for docker w gpus
* rename the docker e2e actions for gh ci
* set privileged mode for docker and update mixtral model self attn check
* use fp16/bf16 for mixtral w fa2
* skip e2e tests on docker w gpus for now
* tests to validate mistral and mixtral patches
* fix rel import
---
 .github/workflows/tests-docker.yml          |  62 +++++++++
 requirements.txt                            |   2 +-
 src/axolotl/monkeypatch/mixtral/__init__.py |   2 +-
 .../monkeypatch/mixtral/modeling_mixtral.py |   8 +-
 src/axolotl/utils/models.py                 |   3 +
 tests/e2e/test_mixtral.py                   | 109 ++++++++++++++++
 tests/e2e/test_mixtral_samplepack.py        | 123 ++++++++++++++++++
 tests/e2e/test_model_patches.py             |  99 ++++++++++++++
 8 files changed, 404 insertions(+), 4 deletions(-)
 create mode 100644 .github/workflows/tests-docker.yml
 create mode 100644 tests/e2e/test_mixtral.py
 create mode 100644 tests/e2e/test_mixtral_samplepack.py
 create mode 100644 tests/e2e/test_model_patches.py

diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml
new file mode 100644
index 000000000..ff30d68ea
--- /dev/null
+++ b/.github/workflows/tests-docker.yml
@@ -0,0 +1,62 @@
+name: e2e-docker-tests
+
+on:
+  pull_request:
+    paths:
+      - '**.py'
+      - 'requirements.txt'
+  workflow_dispatch:
+
+jobs:
+  build-axolotl:
+    if: github.repository_owner == 'OpenAccess-AI-Collective'
+    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 118
+            cuda_version: 11.8.0
+            python_version: "3.10"
+            pytorch: 2.0.1
+            axolotl_extras:
+            is_latest: true
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.10"
+            pytorch: 2.1.1
+            axolotl_extras:
+    runs-on: [self-hosted, gpu, docker]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker metadata
+        id: metadata
+        uses: docker/metadata-action@v5
+        with:
+          images: winglian/axolotl
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
+      - name: Build and export to Docker
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          load: true
+          build-args: |
+            BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            CUDA=${{ matrix.cuda }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
+          file: ./docker/Dockerfile
+          tags: |
+            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
+          labels: ${{ steps.metadata.outputs.labels }}
+      - name: Unit Tests
+        run: |
+          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
diff --git a/requirements.txt b/requirements.txt
index c1c1cbc13..f4df0dd67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.36.2
+transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 418814689..74fa00f64 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
     transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
         mixtral_model_forward
     )
-    transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
+    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
         "flash_attention_2"
     ] = MixtralMultipackFlashAttention2
diff --git a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
index 34f35015f..db892530d 100644
--- a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
+++ b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
@@ -261,7 +261,11 @@ def mixtral_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+    if (
+        attention_mask is not None
+        and self._attn_implementation == "flash_attention_2"
+        and use_cache
+    ):
         is_padding_right = attention_mask[:, -1].sum().item() != batch_size
         if is_padding_right:
             raise ValueError(
@@ -270,7 +274,7 @@
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
" ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = ( attention_mask diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index df6907c3d..fb2420108 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -332,15 +332,18 @@ def load_model( or cfg.is_mistral_derived_model or model_config.model_type == "mixtral" ): + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: if model_config.model_type == "mixtral": + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: + model_kwargs["attn_implementation"] = "eager" model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py new file mode 100644 index 000000000..896cc74d0 --- /dev/null +++ b/tests/e2e/test_mixtral.py @@ -0,0 +1,109 @@ +""" +E2E tests for mixtral +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMixtral(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = 
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
diff --git a/tests/e2e/test_mixtral_samplepack.py b/tests/e2e/test_mixtral_samplepack.py
new file mode 100644
index 000000000..b43702a51
--- /dev/null
+++ b/tests/e2e/test_mixtral_samplepack.py
@@ -0,0 +1,123 @@
+"""
+E2E tests for mixtral
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMixtral(unittest.TestCase):
+    """
+    Test case for Mixtral models using sample packing
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
diff --git a/tests/e2e/test_model_patches.py b/tests/e2e/test_model_patches.py
new file mode 100644
index 000000000..eb1124464
--- /dev/null
+++ b/tests/e2e/test_model_patches.py
@@ -0,0 +1,99 @@
+"""
+E2E smoke tests to check that the monkeypatches are in place for certain configurations
+"""
+
+import unittest
+
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
+from .utils import with_temp_dir
+
+
+class TestModelPatches(unittest.TestCase):
+    """
+    TestCases for the multipack monkey patches
+    """
+
+    @with_temp_dir
+    def test_mixtral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+
+    @with_temp_dir
+    def test_mistral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mistral_attn_hijack_flash"
+            in model.model.layers[0].self_attn.forward.__module__
+        )
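
A minimal usage sketch of the renamed attention class map (illustrative only, not part of the commit); it assumes axolotl, flash-attn, and the pinned transformers revision above are installed, and uses only names that appear in the diff:

# Illustrative sketch: exercise the multipack monkeypatch added above.
from transformers.models.mixtral import modeling_mixtral

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# Stock class registered by transformers for flash attention 2.
print(modeling_mixtral.MIXTRAL_ATTENTION_CLASSES["flash_attention_2"].__name__)

# Swap in axolotl's multipack-aware attention and model forward.
replace_mixtral_attn_with_multipack_flash_attn()
print(modeling_mixtral.MIXTRAL_ATTENTION_CLASSES["flash_attention_2"].__name__)
# Expected: MixtralMultipackFlashAttention2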