Files
axolotl/tests/e2e/integrations/test_kd.py
Wing Lian ccc94da8ad KD fix w/ online distillation (#2700) [skip ci]
* kd fixes

* fix collator setup

* fix input args

* better handling to drop string fields for kd with raw dataset

* kd trainer has kd temp as part of the init

* drop top_k before softmax

* simplify and remove zscore

* WIP chunked KD loss with autograd wrapper

* more fixes and liger-type chunked loss

* collator cls for plugins

* remove debugging

* additional plugin collator kwargs, don't scale up kd loss by t^2

* don't need temp arg to distill method

* online kd wip

* add close to comment block

* support sampling params/max new tokens

* handle when no custom collator is used in plugins

* logsumexp trick:

* fix check

* shift off the first empty token

* fix length of padding

* use max not min

* temp scale kd loss at end

* support for dynamic plugin training args mixins and symmetric kl

* chore: lint

* fix trainer callback base class

* Fix decay

* accept compressed responses for smaller wire payload

* post-rebase lint

* more KD updates

* increase hyperparams_count for gradients for added normalize_topk

* fix to remove attention_mask

* rename vars for consistency

* fix rebase issues

* default to dropping last batch in multipack batch sampler

* improve handling of train len

* init collator_cls_and_kwargs

* explicit drop_last=False when checking for multipack completeness

* use separate v2 loader for kd

* fix kd tests to use subprocess so it picks up kd training args

* default value for kd_beta arg

* use updated dataset for ci

* longer timeout for e2e
2025-06-17 12:09:13 -04:00

146 lines
4.6 KiB
Python

"""
e2e tests for kd trainer support in Axolotl
"""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
    """
    Build the minimal config dict shared by the KD e2e tests.

    The settings are assembled from themed groups and merged in order;
    training artifacts are pointed at ``temp_dir`` through ``output_dir``.
    """
    # Model and tokenizer
    model_settings = {
        "base_model": "Qwen/Qwen3-0.6B",
        "tokenizer_config": "winglian/qwen3-14b-math",
    }

    # Plugins and their kernel/compile switches
    plugin_settings = {
        "plugins": [
            "axolotl.integrations.kd.KDPlugin",
            "axolotl.integrations.liger.LigerPlugin",
        ],
        "liger_rms_norm": True,
        "liger_glu_activation": True,
        "torch_compile": True,
        "chat_template": "llama3",
    }

    # Knowledge-distillation trainer options
    kd_settings = {
        "kd_trainer": True,
        "kd_ce_alpha": 0.1,
        "kd_alpha": 0.9,
        "kd_temperature": 1.0,
        "kd_beta": 0.0,
        "kd_normalize_topk": True,
    }

    # Dataloader options
    dataloader_settings = {
        "dataloader_prefetch_factor": 8,
        "dataloader_num_workers": 4,
        "dataloader_pin_memory": True,
    }

    # Dataset selection; only a single parquet shard is requested
    kd_dataset = {
        "path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
        "type": "chat_template",
        "split": "train",
        "split_thinking": True,
        "eot_tokens": ["<|im_end|>"],
        "data_files": ["train/batch-000000.parquet"],
    }
    dataset_settings = {
        "datasets": [kd_dataset],
        "skip_prepare_dataset": True,
        "val_set_size": 0.0,
    }

    # Sequence packing, batching and optimisation
    training_settings = {
        "sequence_len": 2048,
        "sample_packing": True,
        "pad_to_sequence_len": True,
        "gradient_accumulation_steps": 2,
        "micro_batch_size": 1,
        "num_epochs": 1,
        "optimizer": "adamw_8bit",
        "lr_scheduler": "cosine",
        "learning_rate": 1e-5,
        "bf16": "auto",
        "gradient_checkpointing": True,
        "flash_attention": True,
    }

    token_settings = {
        "special_tokens": {
            "pad_token": "<|end_of_text|>",
            "eos_token": "<|eot_id|>",
        },
    }

    # Run length and output/logging
    run_settings = {
        "max_steps": 5,
        "output_dir": temp_dir,
        "save_safetensors": True,
        "use_tensorboard": True,
    }

    return {
        **model_settings,
        **plugin_settings,
        **kd_settings,
        **dataloader_settings,
        **dataset_settings,
        **training_settings,
        **token_settings,
        **run_settings,
    }
class TestKnowledgeDistillation:
    """
    Test case for Knowledge Distillation

    Each test serialises its config to ``config.yaml`` inside ``temp_dir``,
    runs ``axolotl train`` on it in a subprocess, and then checks that the
    expected weights file exists and that the logged training loss is below
    a threshold.
    """

    @staticmethod
    def _train_from_cfg(cfg, temp_dir):
        """
        Write ``cfg`` to ``<temp_dir>/config.yaml`` and train on it.

        Args:
            cfg: config object exposing ``to_dict()`` (a ``DictDefault``).
            temp_dir: directory that receives the config file; created if
                it does not exist yet.
        """
        # pylint: disable=duplicate-code
        work_dir = Path(temp_dir)
        work_dir.mkdir(parents=True, exist_ok=True)
        config_path = work_dir / "config.yaml"

        # write cfg to yaml file
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        # Training is launched through the CLI in a separate process
        # instead of being called in-process.
        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(config_path),
                "--num-processes",
                "1",
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )

    # While this will run on torch 2.4.x without torch_compile enabled
    # the VRAM requirement is higher than what is available in CI
    @require_torch_2_5_1
    def test_llama_kd(self, temp_dir, kd_min_cfg):
        """Full fine-tune KD run: expects model weights and a train loss below 1.4."""
        cfg = DictDefault(kd_min_cfg)
        self._train_from_cfg(cfg, temp_dir)

        assert (Path(temp_dir) / "model.safetensors").exists()
        check_tensorboard(
            str(Path(temp_dir) / "runs"),
            "train/loss",
            1.4,
            "Train Loss (%s) is too high",
        )

    @pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
    @pytest.mark.parametrize(
        "load_in_8bit",
        [True, False],
    )
    def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
        """LoRA/DoRA KD run: expects adapter weights and a train loss below 1.2."""
        lora_overrides = {
            "load_in_8bit": load_in_8bit,
            "torch_compile": False,
            "adapter": "lora",
            "peft_use_dora": True,
            "lora_target_linear": True,
            "lora_r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.0,
        }
        # In a dict union the right-hand operand wins on duplicate keys, so
        # the overrides must come last; otherwise the base config's
        # "torch_compile": True would replace the False set here.
        cfg = DictDefault(kd_min_cfg | lora_overrides)
        self._train_from_cfg(cfg, temp_dir)

        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
        check_tensorboard(
            str(Path(temp_dir) / "runs"),
            "train/loss",
            1.2,
            "Train Loss (%s) is too high",
        )