From 1447beb1328f8298ae0b7333091609f895fc596e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 24 Apr 2025 13:01:43 -0400 Subject: [PATCH] make sure to validate the config before normalizing so defaults get set (#2554) * make sure to validate the config before normalizing so defaults get set * validation not needed for particular test * remove duplicate validations * set qlora correctly --- tests/e2e/integrations/test_cut_cross_entropy.py | 5 ++++- tests/e2e/integrations/test_liger.py | 4 +++- tests/e2e/patched/test_4d_multipack_llama.py | 4 +++- tests/e2e/patched/test_falcon_samplepack.py | 4 +++- tests/e2e/patched/test_fused_llama.py | 3 ++- tests/e2e/patched/test_llama_s2_attention.py | 4 +++- tests/e2e/patched/test_lora_llama_multipack.py | 4 +++- tests/e2e/patched/test_mistral_samplepack.py | 4 +++- tests/e2e/patched/test_mixtral_samplepack.py | 3 ++- tests/e2e/patched/test_model_patches.py | 4 +++- tests/e2e/patched/test_phi_multipack.py | 6 ++++-- tests/e2e/patched/test_resume.py | 3 ++- tests/e2e/patched/test_unsloth_qlora.py | 5 ++++- tests/e2e/test_embeddings_lr.py | 1 + tests/e2e/test_llama_vision.py | 1 + tests/e2e/test_phi.py | 3 ++- tests/e2e/test_process_reward_model_smollm2.py | 3 ++- tests/test_exact_deduplication.py | 3 ++- 18 files changed, 47 insertions(+), 17 deletions(-) diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 753934563..2ae59a15a 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version -from axolotl.utils.config import normalize_config, prepare_plugins +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists @@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration: # pylint: disable=redefined-outer-name def test_llama_w_cce(self, min_cfg, temp_dir): cfg = DictDefault(min_cfg) + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() @@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration: "bf16": "auto", } ) + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() @@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration: attention_type: True, } ) + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 03c83083d..8ecfc4746 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config, prepare_plugins +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.dict import DictDefault from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 @@ -54,6 +54,7 @@ class LigerIntegrationTestCase: } ) # pylint: disable=duplicate-code + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() @@ -100,6 +101,7 @@ class LigerIntegrationTestCase: } ) # pylint: disable=duplicate-code + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 7beb71145..33ba47abd 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase): "fp16": True, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase): "fp16": True, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index 62ee4f717..0034169af 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index f8f245514..51dfec5f4 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase): cfg.bf16 = True else: cfg.fp16 = True + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index b8ddf10da..3aa36772a 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -11,7 +11,7 @@ import pytest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index e544eb4fd..ab6e87e2a 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase): else: cfg.fp16 = True + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase): "lr_scheduler": "cosine", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index f9e523679..3bc0fcfbc 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 8746c923b..2d4f97084 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index ec09e0c81..8a75db52e 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -6,7 +6,7 @@ import unittest import transformers -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase): "eval_steps": 10, } ) + cfg = validate_config(cfg) normalize_config(cfg) tokenizer = load_tokenizer(cfg) load_model(cfg, tokenizer, inference=False) @@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase): "eval_steps": 10, } ) + cfg = validate_config(cfg) normalize_config(cfg) tokenizer = load_tokenizer(cfg) load_model(cfg, tokenizer, inference=False) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index 70b3ea124..c42ed8baf 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, with_temp_dir @@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase): } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase): "sample_packing": True, "flash_attention": True, "pad_to_sequence_len": True, - "load_in_8bit": False, + "load_in_4bit": True, "adapter": "qlora", "lora_r": 64, "lora_alpha": 32, @@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase): } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 68489ed03..a84759bae 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, most_recent_subdir @@ -68,6 +68,7 @@ class TestResumeLlama: cfg.bf16 = True else: cfg.fp16 = True + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 4cea0d26f..5f8fde6b4 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -10,7 +10,7 @@ import pytest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, check_tensorboard @@ -72,6 +72,7 @@ class TestUnslothQLoRA: } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -122,6 +123,7 @@ class TestUnslothQLoRA: } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -177,6 +179,7 @@ class TestUnslothQLoRA: } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 687b6637f..82b822ad6 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): "use_tensorboard": True, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 3fc12afcc..e1e496ccf 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase): "bf16": True, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 268646432..f531a17c5 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase): "tokenizer_type": "AutoTokenizer", "sequence_len": 2048, "sample_packing": False, - "load_in_8bit": False, + "load_in_4bit": True, "adapter": "qlora", "lora_r": 64, "lora_alpha": 32, @@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index 19347cf92..446facdb0 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -9,7 +9,7 @@ import unittest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, check_tensorboard, with_temp_dir @@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase): "seed": 42, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 4d069a11d..1d41a248d 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -11,7 +11,7 @@ from unittest.mock import patch import pytest from datasets import Dataset -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets @@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase): "num_epochs": 1, } ) + self.cfg_1 = validate_config(self.cfg_1) normalize_config(self.cfg_1) @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")