make the kd e2e fit in vram for ci and add lora version

This commit is contained in:
Wing Lian
2025-01-08 11:07:29 -05:00
parent 1d039f5486
commit 432f65f5e6
2 changed files with 38 additions and 6 deletions

View File

@@ -207,7 +207,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
steps:
@@ -253,7 +253,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -4,7 +4,7 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -23,7 +23,7 @@ def min_cfg(temp_dir):
],
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": False,
"torch_compile": True,
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
@@ -44,11 +44,11 @@ def min_cfg(temp_dir):
},
],
"val_set_size": 0.0,
"sequence_len": 4096,
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -70,6 +70,9 @@ class TestKnowledgeDistillation:
Test case for Knowledge Distillation
"""
# While this will run on torch 2.4.x without torch_compile enabled
# the VRAM requirement is higher than what is available in CI
@require_torch_2_5_1
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
@@ -83,3 +86,32 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
)
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
cfg = DictDefault(
{
"load_in_8bit": load_in_8bit,
"torch_compile": False,
"adapter": "lora",
"peft_use_dora": True,
"lora_target_linear": True,
"lora_r": 16,
"lora_alpha": 32,
}
| kd_min_cfg
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)