From 788649fe9537f4ed3774231073bde70ce5a35881 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 9 Jan 2024 21:23:23 -0500
Subject: [PATCH] attempt to also run e2e tests that need gpus (#1070)

* attempt to also run e2e tests that need gpus

* fix stray quote

* checkout specific github ref

* dockerfile for tests with proper checkout

ensure wandb is disabled for docker pytests

clear wandb env after testing

clear wandb env after testing

make sure to provide a default val for pop

trying skipping wandb validation tests

explicitly disable wandb in the e2e tests

explicitly report_to None to see if that fixes the docker e2e tests

split gpu from non-gpu unit tests

skip bf16 check in test for now

build docker w/o cache since it uses branch name ref

revert some changes now that caching is fixed

skip bf16 check if on gpu w support

* pytest skip for auto-gptq requirements

* skip mamba tests for now, split multipack and non packed lora llama tests

* split tests that use monkeypatches

* fix relative import for prev commit

* move other tests using monkeypatches to the correct run
---
 .github/workflows/tests-docker.yml            |  12 +-
 docker/Dockerfile-tests                       |  40 ++++++
 tests/e2e/patched/__init__.py                 |   0
 tests/e2e/{ => patched}/test_fused_llama.py   |   2 +-
 .../e2e/patched/test_lora_llama_multipack.py  | 126 ++++++++++++++++++
 .../{ => patched}/test_mistral_samplepack.py  |   2 +-
 .../{ => patched}/test_mixtral_samplepack.py  |   2 +-
 tests/e2e/{ => patched}/test_model_patches.py |   2 +-
 tests/e2e/{ => patched}/test_resume.py        |   4 +-
 tests/e2e/test_lora_llama.py                  |  93 -------------
 tests/e2e/test_mamba.py                       |   7 +-
 tests/e2e/test_phi.py                         |  12 +-
 tests/test_validation.py                      |  17 +++
 13 files changed, 214 insertions(+), 105 deletions(-)
 create mode 100644 docker/Dockerfile-tests
 create mode 100644 tests/e2e/patched/__init__.py
 rename tests/e2e/{ => patched}/test_fused_llama.py (98%)
 create mode 100644 tests/e2e/patched/test_lora_llama_multipack.py
 rename tests/e2e/{ => patched}/test_mistral_samplepack.py (99%)
 rename tests/e2e/{ => patched}/test_mixtral_samplepack.py (99%)
 rename tests/e2e/{ => patched}/test_model_patches.py (98%)
 rename tests/e2e/{ => patched}/test_resume.py (96%)

diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml
index 380729637..6059946fc 100644
--- a/.github/workflows/tests-docker.yml
+++ b/.github/workflows/tests-docker.yml
@@ -36,11 +36,19 @@ jobs:
           PYTORCH_VERSION="${{ matrix.pytorch }}"
           # Build the Docker image
           docker build . \
-            --file ./docker/Dockerfile \
+            --file ./docker/Dockerfile-tests \
             --build-arg BASE_TAG=$BASE_TAG \
             --build-arg CUDA=$CUDA \
+            --build-arg GITHUB_REF=$GITHUB_REF \
             --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
-            --tag test-axolotl
+            --tag test-axolotl \
+            --no-cache
       - name: Unit Tests w docker image
         run: |
           docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
+      - name: GPU Unit Tests w docker image
+        run: |
+          docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
+      - name: GPU Unit Tests monkeypatched w docker image
+        run: |
+          docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest /workspace/axolotl/tests/e2e/patched/
diff --git a/docker/Dockerfile-tests b/docker/Dockerfile-tests
new file mode 100644
index 000000000..2ec94f868
--- /dev/null
+++ b/docker/Dockerfile-tests
@@ -0,0 +1,40 @@
+ARG BASE_TAG=main-base
+FROM winglian/axolotl-base:$BASE_TAG
+
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
+ARG AXOLOTL_EXTRAS=""
+ARG CUDA="118"
+ENV BNB_CUDA_VERSION=$CUDA
+ARG PYTORCH_VERSION="2.0.1"
+ARG GITHUB_REF="main"
+
+ENV PYTORCH_VERSION=$PYTORCH_VERSION
+
+RUN apt-get update && \
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
+
+WORKDIR /workspace
+
+RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
+
+WORKDIR /workspace/axolotl
+
+RUN git fetch origin +$GITHUB_REF && \
+    git checkout FETCH_HEAD
+
+# If AXOLOTL_EXTRAS is set, append it in brackets
+RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
+    else \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
+    fi
+
+# So we can test the Docker image
+RUN pip install pytest
+
+# fix so that git fetch/pull from remote works
+RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
+    git config --get remote.origin.fetch
+
+# helper for huggingface-login cli
+RUN git config --global credential.helper store
diff --git a/tests/e2e/patched/__init__.py b/tests/e2e/patched/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
similarity index 98%
rename from tests/e2e/test_fused_llama.py
rename to tests/e2e/patched/test_fused_llama.py
index 513df69f9..96ff5eee8 100644
--- a/tests/e2e/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
new file mode 100644
index 000000000..079a8e924
--- /dev/null
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -0,0 +1,126 @@
+"""
+E2E tests for lora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import pytest
+from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestLoraLlama(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA w multipack
+    """
+
+    @with_temp_dir
+    def test_lora_packing(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
+    @with_temp_dir
+    def test_lora_gptq_packed(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "model_type": "AutoModelForCausalLM",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "gptq": True,
+                "gptq_disable_exllama": True,
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "save_steps": 0.5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
similarity index 99%
rename from tests/e2e/test_mistral_samplepack.py
rename to tests/e2e/patched/test_mistral_samplepack.py
index cefbd7dc0..c0327d7ef 100644
--- a/tests/e2e/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
diff --git a/tests/e2e/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
similarity index 99%
rename from tests/e2e/test_mixtral_samplepack.py
rename to tests/e2e/patched/test_mixtral_samplepack.py
index b43702a51..4eff3825a 100644
--- a/tests/e2e/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
diff --git a/tests/e2e/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
similarity index 98%
rename from tests/e2e/test_model_patches.py
rename to tests/e2e/patched/test_model_patches.py
index eb1124464..65d372c73 100644
--- a/tests/e2e/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -9,7 +9,7 @@ from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 
 class TestModelPatches(unittest.TestCase):
diff --git a/tests/e2e/test_resume.py b/tests/e2e/patched/test_resume.py
similarity index 96%
rename from tests/e2e/test_resume.py
rename to tests/e2e/patched/test_resume.py
index 98ec3ac6b..dfe9e8625 100644
--- a/tests/e2e/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -17,7 +17,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import most_recent_subdir, with_temp_dir
+from ..utils import most_recent_subdir, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -29,7 +29,7 @@ class TestResumeLlama(unittest.TestCase):
     """
 
     @with_temp_dir
-    def test_resume_qlora(self, temp_dir):
+    def test_resume_qlora_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 9d795601a..c79652bef 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -65,96 +65,3 @@ class TestLoraLlama(unittest.TestCase):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
-
-    @with_temp_dir
-    def test_lora_packing(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "bf16": True,
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
-
-    @with_temp_dir
-    def test_lora_gptq(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
-                "model_type": "AutoModelForCausalLM",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "gptq": True,
-                "gptq_disable_exllama": True,
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "save_steps": 0.5,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 463b0ddac..8755fa4d5 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -7,6 +7,8 @@ import os
 import unittest
 from pathlib import Path
 
+import pytest
+
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -19,9 +21,10 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-class TestMistral(unittest.TestCase):
+@pytest.mark.skip(reason="skipping until upstreamed into transformers")
+class TestMamba(unittest.TestCase):
     """
-    Test case for Llama models using LoRA
+    Test case for Mamba models
     """
 
     @with_temp_dir
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index b21fc14ff..80c748cc9 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -8,6 +8,7 @@ import unittest
 from pathlib import Path
 
 import pytest
+from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -59,7 +60,6 @@ class TestPhi(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "paged_adamw_8bit",
                 "lr_scheduler": "cosine",
-                "bf16": True,
                 "flash_attention": True,
                 "max_steps": 10,
                 "save_steps": 10,
@@ -67,6 +67,10 @@
                 "save_safetensors": True,
             }
         )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,9 +114,13 @@
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
-                "bf16": True,
             }
         )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 12997b023..d2518a7df 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -6,6 +6,7 @@ import unittest
 from typing import Optional
 
 import pytest
+from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
@@ -354,6 +355,10 @@ class ValidationTest(unittest.TestCase):
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
 
+    @pytest.mark.skipif(
+        is_torch_bf16_gpu_available(),
+        reason="test should only run on gpus w/o bf16 support",
+    )
     def test_merge_lora_no_bf16_fail(self):
         """
         This is assumed to be run on a CPU machine, so bf16 is not supported.
@@ -778,6 +783,15 @@
         assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
         assert os.environ.get("WANDB_DISABLED", "") != "true"
 
+        os.environ.pop("WANDB_PROJECT", None)
+        os.environ.pop("WANDB_NAME", None)
+        os.environ.pop("WANDB_RUN_ID", None)
+        os.environ.pop("WANDB_ENTITY", None)
+        os.environ.pop("WANDB_MODE", None)
+        os.environ.pop("WANDB_WATCH", None)
+        os.environ.pop("WANDB_LOG_MODEL", None)
+        os.environ.pop("WANDB_DISABLED", None)
+
     def test_wandb_set_disabled(self):
         cfg = DictDefault({})
 
@@ -798,3 +812,6 @@
         setup_wandb_env_vars(cfg)
 
         assert os.environ.get("WANDB_DISABLED", "") != "true"
+
+        os.environ.pop("WANDB_PROJECT", None)
+        os.environ.pop("WANDB_DISABLED", None)
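
Note: the validation tests above tear down the WANDB_* variables with explicit os.environ.pop() calls at the end of each test body. For comparison only (this fixture is not part of the patch, and the name restore_wandb_env is invented for illustration), an autouse pytest fixture in a conftest.py can snapshot and restore the environment, so cleanup also runs when an assertion fails partway through:

    # Illustrative sketch only -- not part of this patch; the tests above
    # clean up with explicit os.environ.pop() calls instead.
    import os

    import pytest


    @pytest.fixture(autouse=True)
    def restore_wandb_env():
        """Snapshot WANDB_* env vars before a test and restore them after."""
        saved = {k: v for k, v in os.environ.items() if k.startswith("WANDB_")}
        yield
        # Drop whatever the test set or left behind ...
        for key in [k for k in os.environ if k.startswith("WANDB_")]:
            os.environ.pop(key, None)
        # ... then put the original values back.
        os.environ.update(saved)

Because the fixture is autouse, every test in scope gets the snapshot/restore behavior without repeating the pop() boilerplate.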