make the kd e2e fit in vram for ci and add lora version
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -207,7 +207,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.5.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
@@ -253,7 +253,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.4.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ e2e tests for kd trainer support in Axolotl
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from e2e.utils import check_tensorboard
|
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
@@ -23,7 +23,7 @@ def min_cfg(temp_dir):
|
|||||||
],
|
],
|
||||||
"liger_rms_norm": True,
|
"liger_rms_norm": True,
|
||||||
"liger_glu_activation": True,
|
"liger_glu_activation": True,
|
||||||
"torch_compile": False,
|
"torch_compile": True,
|
||||||
"chat_template": "llama3",
|
"chat_template": "llama3",
|
||||||
"kd_trainer": True,
|
"kd_trainer": True,
|
||||||
"kd_ce_alpha": 0.1,
|
"kd_ce_alpha": 0.1,
|
||||||
@@ -44,11 +44,11 @@ def min_cfg(temp_dir):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"sequence_len": 4096,
|
"sequence_len": 2048,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 1,
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -70,6 +70,9 @@ class TestKnowledgeDistillation:
|
|||||||
Test case for Knowledge Distillation
|
Test case for Knowledge Distillation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# While this will run on torch 2.4.x without torch_compile enabled
|
||||||
|
# the VRAM requirement is higher than what is available in CI
|
||||||
|
@require_torch_2_5_1
|
||||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||||
cfg = DictDefault(kd_min_cfg)
|
cfg = DictDefault(kd_min_cfg)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -83,3 +86,32 @@ class TestKnowledgeDistillation:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"load_in_8bit",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_8bit": load_in_8bit,
|
||||||
|
"torch_compile": False,
|
||||||
|
"adapter": "lora",
|
||||||
|
"peft_use_dora": True,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
}
|
||||||
|
| kd_min_cfg
|
||||||
|
)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user