From 5a1c1b82d48b57d38cb46c157049c78ffc8b3a0f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 8 Jan 2025 11:07:29 -0500
Subject: [PATCH] make the kd e2e fit in vram for ci and add lora version

---
 .github/workflows/tests.yml       |  4 ++--
 tests/e2e/integrations/test_kd.py | 40 +++++++++++++++++++++++++++----
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a2a0e801e..c6f408655 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -207,7 +207,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.4.1
+            pytorch: 2.5.1
             num_gpus: 1
             axolotl_extras:
     steps:
@@ -253,7 +253,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.1
             num_gpus: 1
             axolotl_extras:
     steps:
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 147f4fc78..919a73dce 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -4,7 +4,7 @@ e2e tests for kd trainer support in Axolotl
 from pathlib import Path
 
 import pytest
-from e2e.utils import check_tensorboard
+from e2e.utils import check_tensorboard, require_torch_2_5_1
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -23,7 +23,7 @@ def min_cfg(temp_dir):
         ],
         "liger_rms_norm": True,
         "liger_glu_activation": True,
-        "torch_compile": False,
+        "torch_compile": True,
         "chat_template": "llama3",
         "kd_trainer": True,
         "kd_ce_alpha": 0.1,
@@ -44,11 +44,11 @@ def min_cfg(temp_dir):
             },
         ],
         "val_set_size": 0.0,
-        "sequence_len": 4096,
+        "sequence_len": 2048,
         "sample_packing": True,
         "pad_to_sequence_len": True,
         "gradient_accumulation_steps": 2,
-        "micro_batch_size": 2,
+        "micro_batch_size": 1,
         "num_epochs": 1,
         "optimizer": "adamw_8bit",
         "lr_scheduler": "cosine",
@@ -70,6 +70,9 @@ class TestKnowledgeDistillation:
     Test case for Knowledge Distillation
     """
 
+    # While this will run on torch 2.4.x without torch_compile enabled,
+    # the VRAM requirement is higher than what is available in CI.
+    @require_torch_2_5_1
     def test_llama_kd(self, temp_dir, kd_min_cfg):
         cfg = DictDefault(kd_min_cfg)
         # pylint: disable=duplicate-code
@@ -83,3 +86,32 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+
+    @pytest.mark.parametrize(
+        "load_in_8bit",
+        [True, False],
+    )
+    def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
+        cfg = DictDefault(
+            {
+                "load_in_8bit": load_in_8bit,
+                "torch_compile": False,
+                "adapter": "lora",
+                "peft_use_dora": True,
+                "lora_target_linear": True,
+                "lora_r": 16,
+                "lora_alpha": 32,
+            }
+            | kd_min_cfg
+        )
+        # pylint: disable=duplicate-code
+        prepare_plugins(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
+        )
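
Note: the test changes import require_torch_2_5_1 from e2e/utils.py, a
helper that lives outside this diff. For context, a minimal sketch of what
such a pytest version gate typically looks like, assuming the packaging
library is available; this is an illustration, not the verbatim contents
of e2e/utils.py:

    import pytest
    import torch
    from packaging.version import Version

    # Skip a test unless the installed torch is at least 2.5.1; strip any
    # local build tag (e.g. "2.5.1+cu124") before comparing versions.
    require_torch_2_5_1 = pytest.mark.skipif(
        Version(torch.__version__.split("+")[0]) < Version("2.5.1"),
        reason="requires torch>=2.5.1",
    )

Applied as @require_torch_2_5_1, a gate like this keeps the torch_compile
KD test off the pytorch 2.4.1 CI image, while test_llama_lora_kd (which
sets torch_compile: False) still runs on both images.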