From 0402d1975925bcd88a8094c836b22d50e61c0cb2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 5 Nov 2023 08:02:31 -0500 Subject: [PATCH] make sure to cleanup tmp output_dir for e2e tests --- tests/e2e/test_fused_llama.py | 6 +++--- tests/e2e/test_lora_llama.py | 14 +++++++------- tests/e2e/test_mistral.py | 10 +++++----- tests/e2e/test_mistral_samplepack.py | 10 +++++----- tests/e2e/test_phi.py | 15 ++++++++++----- tests/utils.py | 22 ++++++++++++++++++++++ 6 files changed, 52 insertions(+), 25 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/e2e/test_fused_llama.py b/tests/e2e/test_fused_llama.py index 9363f333c..6979707c3 100644 --- a/tests/e2e/test_fused_llama.py +++ b/tests/e2e/test_fused_llama.py @@ -4,7 +4,6 @@ E2E tests for lora llama import logging import os -import tempfile import unittest from pathlib import Path @@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault +from tests.utils import with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -25,9 +25,9 @@ class TestFusedLlama(unittest.TestCase): Test case for Llama models using Fused layers """ - def test_fft_packing(self): + @with_temp_dir + def test_fft_packing(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "JackFram/llama-68m", diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 4f50d8194..dcdfdb0ec 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -4,7 +4,6 @@ E2E tests for lora llama import logging import os -import tempfile import unittest from pathlib import Path @@ -13,6 +12,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault +from tests.utils 
import with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -23,9 +23,9 @@ class TestLoraLlama(unittest.TestCase): Test case for Llama models using LoRA """ - def test_lora(self): + @with_temp_dir + def test_lora(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "JackFram/llama-68m", @@ -65,9 +65,9 @@ class TestLoraLlama(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(output_dir) / "adapter_model.bin").exists() - def test_lora_packing(self): + @with_temp_dir + def test_lora_packing(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "JackFram/llama-68m", @@ -109,9 +109,9 @@ class TestLoraLlama(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(output_dir) / "adapter_model.bin").exists() - def test_lora_gptq(self): + @with_temp_dir + def test_lora_gptq(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index f2928a727..7db40f005 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -4,7 +4,6 @@ E2E tests for lora llama import logging import os -import tempfile import unittest from pathlib import Path @@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault +from tests.utils import with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase): Test case for Llama models using LoRA """ - def test_lora(self): + @with_temp_dir + def test_lora(self, output_dir): # pylint: 
disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", @@ -70,9 +70,9 @@ class TestMistral(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(output_dir) / "adapter_model.bin").exists() - def test_ft(self): + @with_temp_dir + def test_ft(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", diff --git a/tests/e2e/test_mistral_samplepack.py b/tests/e2e/test_mistral_samplepack.py index 5fadf0959..cf1aa57f5 100644 --- a/tests/e2e/test_mistral_samplepack.py +++ b/tests/e2e/test_mistral_samplepack.py @@ -4,7 +4,6 @@ E2E tests for lora llama import logging import os -import tempfile import unittest from pathlib import Path @@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault +from tests.utils import with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase): Test case for Llama models using LoRA """ - def test_lora_packing(self): + @with_temp_dir + def test_lora_packing(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", @@ -71,9 +71,9 @@ class TestMistral(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(output_dir) / "adapter_model.bin").exists() - def test_ft_packing(self): + @with_temp_dir + def test_ft_packing(self, output_dir): # pylint: disable=duplicate-code - output_dir = tempfile.mkdtemp() cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index f9ea52ea2..4f121038d 100644 --- 
a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -4,14 +4,15 @@ E2E tests for lora llama import logging import os -import tempfile import unittest +from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault +from tests.utils import with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -22,7 +23,8 @@ class TestPhi(unittest.TestCase): Test case for Llama models using LoRA """ - def test_ft(self): + @with_temp_dir + def test_ft(self, output_dir): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -52,7 +54,7 @@ class TestPhi(unittest.TestCase): "num_epochs": 1, "micro_batch_size": 1, "gradient_accumulation_steps": 1, - "output_dir": tempfile.mkdtemp(), + "output_dir": output_dir, "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", @@ -64,8 +66,10 @@ class TestPhi(unittest.TestCase): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(output_dir) / "pytorch_model.bin").exists() - def test_ft_packed(self): + @with_temp_dir + def test_ft_packed(self, output_dir): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -95,7 +99,7 @@ class TestPhi(unittest.TestCase): "num_epochs": 1, "micro_batch_size": 1, "gradient_accumulation_steps": 1, - "output_dir": tempfile.mkdtemp(), + "output_dir": output_dir, "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", @@ -107,3 +111,4 @@ class TestPhi(unittest.TestCase): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(output_dir) / "pytorch_model.bin").exists() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..1c044778f --- /dev/null +++ 
"""
helper utils for tests
"""

import shutil
import tempfile
from functools import wraps


def with_temp_dir(test_func):
    """Decorator that provides a throwaway temp directory to a test.

    The directory path is passed as the LAST positional argument so that
    instance methods keep ``self`` first, e.g.::

        @with_temp_dir
        def test_lora(self, temp_dir): ...

    The directory is always removed afterwards, even if the test raises.
    """

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # Create a temporary directory
        temp_dir = tempfile.mkdtemp()
        try:
            # Append the temp dir AFTER *args: for bound methods args[0]
            # is `self`, and prepending would hand the path to `self`.
            test_func(*args, temp_dir, **kwargs)
        finally:
            # Clean up the directory after the test
            shutil.rmtree(temp_dir)

    return wrapper