diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 47d4b4839..6944c6f5e 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -14,17 +14,14 @@ class TestPretrainLlama: """Test case for Llama models w pretraining""" @pytest.mark.parametrize( - "sample_packing", - [True, False], - ) - @pytest.mark.parametrize( - "pretrain_multipack_attn", - [True, False], + ("sample_packing", "pretrain_multipack_attn"), + [ + (False, False), + (True, True), + (True, False), + ], ) def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn): - if not sample_packing and pretrain_multipack_attn: - return - # pylint: disable=duplicate-code cfg = DictDefault( { @@ -65,7 +62,7 @@ class TestPretrainLlama: train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - loss_threshold = 3.5 + loss_threshold = 3.6 if sample_packing and not pretrain_multipack_attn: loss_threshold = 6.5 check_tensorboard(