KD fix w/ online distillation (#2700) [skip ci]

* kd fixes

* fix collator setup

* fix input args

* better handling to drop string fields for kd with raw dataset

* kd trainer has kd temp as part of the init

* drop top_k before softmax

* simplfy and remove zscore

* WIP chunked KD loss with autograd wrapper

* more fixes and liger-type chunked loss

* collator cls for plugins

* remove debugging

* additional plugin collator kwargs, don't scale up kd loss by t^2

* don't need temp arg to distill method

* online kd wip

* add close to comment block

* suport sampling params/max new tokens

* handle when no custom collator is used in plugins

* logsumexp trick:

* fix check

* shift off the first empty token

* fix length of padding

* use max not min

* temp scale kd loss at end

* support for dynamic plugin training args mixins and symmetric kl

* chore: lint

* fix trainer callback base class

* Fix decay

* accept compressed responses for smaller wire payload

* post-rebase lint

* more KD updates

* increase hyperparams_count for gradients for added normalize_topk

* fix to remove attention_mask

* rename vars for consistency

* fix rebase issues

* default to dropping last batch in multipack batch sampler

* improve handling of train len

* init collator_cls_and_kwargs

* explicit drop_last=False when checking for multipack completeness

* use separate v2 loader for kd

* fix kd tests to use subprocess so it picks up kd training args

* default value for kd_beta arg

* use updated dataset for ci

* longer timeout for e2e
This commit is contained in:
Wing Lian
2025-06-17 12:09:13 -04:00
committed by GitHub
parent ba62aa65ee
commit ccc94da8ad
31 changed files with 2178 additions and 527 deletions

View File

@@ -5,10 +5,9 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@@ -17,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"base_model": "Qwen/Qwen3-0.6B",
"tokenizer_config": "winglian/qwen3-14b-math",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
@@ -31,20 +30,22 @@ def min_cfg(temp_dir):
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 1.0,
"kd_beta": 0.0,
"kd_normalize_topk": True,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
"type": "chat_template",
"split": "train",
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
"split_thinking": True,
"eot_tokens": ["<|im_end|>"],
"data_files": ["train/batch-000000.parquet"],
},
],
"skip_prepare_dataset": True,
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
@@ -80,17 +81,29 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -110,12 +123,22 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
train(cfg=cfg, dataset_meta=dataset_meta)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"

View File

@@ -81,6 +81,7 @@ class TestBatchedSamplerPacking:
group_size=100000,
bin_size=200,
sequential=sequential,
drop_last=False,
)
loader = DataLoader(