attempt to also run e2e tests that needs gpus (#1070)

* attempt to also run e2e tests that needs gpus

* fix stray quote

* checkout specific github ref

* dockerfile for tests with proper checkout

ensure wandb is dissabled for docker pytests
clear wandb env after testing
clear wandb env after testing
make sure to provide a default val for pop
tryin skipping wandb validation tests
explicitly disable wandb in the e2e tests
explicitly report_to None to see if that fixes the docker e2e tests
split gpu from non-gpu unit tests
skip bf16 check in test for now
build docker w/o cache since it uses branch name ref
revert some changes now that caching is fixed
skip bf16 check if on gpu w support

* pytest skip for auto-gptq requirements

* skip mamba tests for now, split multipack and non packed lora llama tests

* split tests that use monkeypatches

* fix relative import for prev commit

* move other tests using monkeypatches to the correct run
This commit is contained in:
Wing Lian
2024-01-09 21:23:23 -05:00
committed by GitHub
parent 9be92d1448
commit 788649fe95
13 changed files with 214 additions and 105 deletions

View File

@@ -6,6 +6,7 @@ import unittest
from typing import Optional
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@@ -354,6 +355,10 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
@pytest.mark.skipif(
is_torch_bf16_gpu_available(),
reason="test should only run on gpus w/o bf16 support",
)
def test_merge_lora_no_bf16_fail(self):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.
@@ -778,6 +783,15 @@ class ValidationWandbTest(ValidationTest):
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_NAME", None)
os.environ.pop("WANDB_RUN_ID", None)
os.environ.pop("WANDB_ENTITY", None)
os.environ.pop("WANDB_MODE", None)
os.environ.pop("WANDB_WATCH", None)
os.environ.pop("WANDB_LOG_MODEL", None)
os.environ.pop("WANDB_DISABLED", None)
def test_wandb_set_disabled(self):
cfg = DictDefault({})
@@ -798,3 +812,6 @@ class ValidationWandbTest(ValidationTest):
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)