diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index f0e2573f7..85d241b5d 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -27,7 +27,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
         field_type = next(
             t for t in get_args(field_type) if not isinstance(t, NoneType)
         )
-
         if field_type == bool:
             field_name = field.name.replace("_", "-")
             option_name = f"--{field_name}/--no-{field_name}"
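Reviewer note: the src/axolotl/cli/utils.py hunk above only drops a stray blank line. For context, the surrounding loop maps a boolean (possibly Optional[bool]) dataclass field onto a paired click flag. A minimal, self-contained sketch of that pattern, assuming a hypothetical `debug` field (the names here are illustrative, not part of this diff):

    from dataclasses import dataclass, fields
    from types import NoneType
    from typing import Optional, get_args, get_origin

    import click


    @dataclass
    class ExampleArgs:
        debug: Optional[bool] = None  # hypothetical CLI arg


    @click.command()
    def cli(**kwargs):
        """Toy command that receives the generated options."""


    for field in fields(ExampleArgs):
        field_type = field.type
        if get_origin(field_type) is not None:
            # unwrap Optional[bool] -> bool, as the next(...) call in the hunk does
            field_type = next(t for t in get_args(field_type) if t is not NoneType)
        if field_type == bool:
            field_name = field.name.replace("_", "-")
            # a bool field becomes a dual flag: --debug/--no-debug
            cli = click.option(f"--{field_name}/--no-{field_name}", default=None)(cli)

Running the command with --no-debug then passes debug=False through to the command's kwargs.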
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index a74813e3a..6562af176 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -2,8 +2,6 @@
 Simple end-to-end test for Cut Cross Entropy integration
 """

-from pathlib import Path
-
 import pytest

 from axolotl.cli import load_datasets
@@ -13,6 +11,8 @@ from axolotl.utils import get_pytorch_version
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

+from ..utils import check_model_output_exists
+
 # pylint: disable=duplicate-code

@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-            assert (Path(temp_dir) / "model.safetensors").exists()
+            check_model_output_exists(temp_dir, cfg)

     @pytest.mark.parametrize(
         "attention_type",
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
             train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-            assert (Path(temp_dir) / "model.safetensors").exists()
+            check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index ce9299b92..9154bf9b8 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -1,7 +1,6 @@
 """
 Simple end-to-end test for Liger integration
 """
-from pathlib import Path

 from e2e.utils import require_torch_2_4_1

@@ -11,6 +10,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

+from ..utils import check_model_output_exists
+

 class LigerIntegrationTestCase:
     """
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     @require_torch_2_4_1
     def test_llama_w_flce(self, temp_dir):
@@ -105,4 +106,4 @@ class LigerIntegrationTestCase:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index b0ada9230..08b3bf0da 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -5,7 +5,6 @@ E2E tests for multipack fft llama using 4d attention masks
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import require_torch_2_3_1, with_temp_dir
+from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_torch_lora_packing(self, temp_dir):
@@ -111,4 +110,4 @@ class Test4dMultipackLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index 183843b7b..791d955b2 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard
+from ..utils import check_model_output_exists, check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +81,7 @@ class TestFAXentropyLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index d9d715103..69516810f 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for falcon
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestFalconPatched(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -109,4 +108,4 @@ class TestFalconPatched(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 36b7442d9..23a0adfc0 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -73,4 +72,4 @@ class TestFusedLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index 0f2539daf..d0fdd918a 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -5,7 +5,6 @@ E2E tests for llama w/ S2 attn
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_fft_s2_attn(self, temp_dir):
@@ -111,4 +110,4 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index be2f133fb..634e544d2 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest
 from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -76,7 +75,7 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
     @with_temp_dir
@@ -126,4 +125,4 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 6685fb9d5..e93863e09 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft_packing(self, temp_dir):
@@ -110,4 +109,4 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 684baaaff..f87c34fd1 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -5,7 +5,6 @@ E2E tests for mixtral
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +65,7 @@ class TestMixtral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -108,4 +107,4 @@ class TestMixtral(unittest.TestCase):
             "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index 7b5bf92df..852ac7bec 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestPhiMultipack(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_qlora_packed(self, temp_dir):
@@ -120,4 +119,4 @@ class TestPhiMultipack(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 7d82ea8c3..5639d2eae 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -6,7 +6,6 @@
 import logging
 import os
 import re
 import subprocess
-from pathlib import Path

 from transformers.utils import is_torch_bf16_gpu_available

@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import most_recent_subdir
+from ..utils import check_model_output_exists, most_recent_subdir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -83,7 +82,7 @@ class TestResumeLlama:
         cli_args = TrainerCliArgs()

         train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
         cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 0c0ee8610..492bc1c23 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -3,7 +3,6 @@ e2e tests for unsloth qlora
 """
 import logging
 import os
-from pathlib import Path

 import pytest

@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard
+from ..utils import check_model_output_exists, check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -77,7 +76,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -127,7 +126,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -182,7 +181,7 @@ class TestUnslothQLoRA:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
os.environ["WANDB_DISABLED"] = "true" @@ -68,7 +68,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir def test_dpo_nll_lora(self, temp_dir): @@ -113,7 +113,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir def test_dpo_use_weighting(self, temp_dir): @@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip("kto_pair no longer supported in trl") @with_temp_dir @@ -203,7 +203,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir def test_ipo_lora(self, temp_dir): @@ -247,7 +247,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir def test_orpo_lora(self, temp_dir): @@ -294,7 +294,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip(reason="Fix the implementation") @with_temp_dir @@ -358,4 +358,4 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 6e5ebd05f..222d620ae 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -5,7 +5,6 @@ E2E tests for llama pretrain import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_tensorboard, with_temp_dir +from .utils import check_model_output_exists, check_tensorboard, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -62,7 +61,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): dataset_meta = load_datasets(cfg=cfg, 
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 6e5ebd05f..222d620ae 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -62,7 +61,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -106,7 +105,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index c76699a7c..117de6635 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -5,7 +5,6 @@ E2E tests for falcon
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_lora_added_vocab(self, temp_dir):
@@ -124,7 +123,7 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -163,4 +162,4 @@ class TestFalcon(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 1ce9d60b9..4384bb61e 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -4,7 +4,8 @@ E2E tests for llama

 import logging
 import os
-from pathlib import Path
+
+from e2e.utils import check_model_output_exists

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -60,7 +61,7 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     def test_fix_untrained_tokens(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -103,7 +104,7 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     def test_batch_flattening(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -142,4 +143,4 @@ class TestLlama:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index 62fb63c47..d13b10659 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -64,4 +63,4 @@ class TestPretrainLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index 1d583a326..250cf418c 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +67,7 @@ class TestLlamaVision(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
@@ -113,4 +112,4 @@ class TestLlamaVision(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index d06be60b9..a7ead64a5 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 8755fa4d5..a1fc30862 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 import pytest

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestMamba(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 57d85e51e..2e79fec8d 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from transformers.utils import is_torch_bf16_gpu_available

@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -112,4 +111,4 @@ class TestMistral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index d4dad14ef..6792d05a6 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -5,7 +5,6 @@ E2E tests for mixtral
 import logging
 import os
 import unittest
-from pathlib import Path

 import torch
 from transformers.utils import is_torch_bf16_gpu_available

@@ -16,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -79,7 +78,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_qlora_wo_fa2(self, temp_dir):
@@ -133,7 +132,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_16bit_lora_w_fa2(self, temp_dir):
@@ -190,7 +189,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_16bit_lora_wo_fa2(self, temp_dir):
@@ -247,7 +246,7 @@ class TestMixtral(unittest.TestCase):
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
         )
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -287,4 +286,4 @@ class TestMixtral(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index f69d0500f..f1bbaafd5 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -5,7 +5,6 @@ E2E tests for custom optimizers using Llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import require_torch_2_5_1, with_temp_dir
+from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,7 +64,7 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     @require_torch_2_5_1
@@ -109,7 +108,7 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_fft_schedule_free_adamw(self, temp_dir):
@@ -145,4 +144,4 @@ class TestCustomOptimizers(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 4cc6bcdcc..7a08d0c6f 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -5,7 +5,6 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "pytorch_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
     def test_phi_qlora(self, temp_dir):
@@ -116,4 +115,4 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py
index 84582896d..fef6a3d30 100644
--- a/tests/e2e/test_relora_llama.py
+++ b/tests/e2e/test_relora_llama.py
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
         assert (
-            Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
-        ).exists()
-        assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
+            Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
+        ).exists(), "Relora model checkpoint not found"

         check_tensorboard(
             temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py
index 27ac3e25f..c4cb705ea 100644
--- a/tests/e2e/test_reward_model_llama.py
+++ b/tests/e2e/test_reward_model_llama.py
@@ -5,7 +5,6 @@ E2E tests for reward model lora llama
 import logging
 import os
 import unittest
-from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import check_model_output_exists, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,4 +70,4 @@ class TestRewardModelLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 1e05c32c4..759d59659 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -14,6 +14,8 @@ import torch
 from packaging import version
 from tbparse import SummaryReader

+from axolotl.utils.dict import DictDefault
+

 def with_temp_dir(test_func):
     @wraps(test_func)
@@ -93,3 +95,27 @@ def check_tensorboard(
     df = reader.scalars  # pylint: disable=invalid-name
     df = df[(df.tag == tag)]  # pylint: disable=invalid-name
     assert df.value.values[-1] < lt_val, assertion_err
+
+
+def check_model_output_exists(temp_dir: str | Path, cfg: DictDefault) -> None:
+    """
+    Helper to check that a model output file exists after training.
+
+    Which file to expect depends on whether an adapter was trained and whether safetensors saves are enabled.
+    """
+
+    if cfg.save_safetensors:
+        if not cfg.adapter:
+            assert (Path(temp_dir) / "model.safetensors").exists()
+        else:
+            assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+    else:
+        # check for both formats, because trl often defaults to saving safetensors anyway
+        if not cfg.adapter:
+            assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
+                Path(temp_dir) / "model.safetensors"
+            ).exists()
+        else:
+            assert (Path(temp_dir) / "adapter_model.bin").exists() or (
+                Path(temp_dir) / "adapter_model.safetensors"
+            ).exists()
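For reviewers, a quick sketch of what the new helper asserts for the two config knobs it reads. The cfg values and output directory below are hypothetical, and the import path follows the one used in tests/e2e/test_llama.py:

    from axolotl.utils.dict import DictDefault

    from e2e.utils import check_model_output_exists

    temp_dir = "outputs/test-run"  # hypothetical output dir of a finished run

    # full fine-tune saved as safetensors -> asserts model.safetensors exists
    cfg = DictDefault({"save_safetensors": True, "adapter": None})
    check_model_output_exists(temp_dir, cfg)

    # LoRA run without save_safetensors -> accepts adapter_model.bin or
    # adapter_model.safetensors, since trl-based trainers may write safetensors anyway
    cfg = DictDefault({"save_safetensors": False, "adapter": "lora"})
    check_model_output_exists(temp_dir, cfg)

This lookup is what lets the per-file asserts on model.safetensors, pytorch_model.bin, and adapter_model.bin throughout the diff collapse into a single helper call.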