fix: Enable KD plugin support for PEFT/LoRA adapters (#3207)

- Fix `_loss_function` attribute not being found on the base model when the model is wrapped by PEFT
- Fix mismatched attribute name: the KD forward pass read `loss_function` while the trainer set `_loss_function`
- Set _loss_function on unwrapped base model for PEFT
- Enable previously skipped test_llama_lora_kd test
- Fix the LoRA test config for kernel compatibility (disable the LoRA MLP/QKV/O kernels and add `lora_modules_to_save` for `embed_tokens` and `lm_head`)

Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
This commit is contained in:
Hitesh Sagtani
2025-10-10 18:27:00 +05:30
committed by GitHub
parent 153edcfe79
commit bc2ffb8204
3 changed files with 16 additions and 4 deletions

View File

@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function( loss = self._loss_function(
self.lm_head.weight, self.lm_head.weight,
hidden_states, hidden_states,
target_token_ids, target_token_ids,

View File

@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss self.args.kd_alpha, # kd loss
self.args.kd_temperature, self.args.kd_temperature,
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
compute_ce_loss=bool(self.args.kd_ce_alpha), compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk, normalize_topk=self.args.kd_normalize_topk,
) )
target = self.model
# Unwrap PEFT wrapper
if hasattr(target, "get_base_model"):
target = target.get_base_model()
# Set on the actual model instance
target._loss_function = loss_fn
def _set_signature_columns_if_needed(self): def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed() super()._set_signature_columns_if_needed()

View File

@@ -104,7 +104,6 @@ class TestKnowledgeDistillation:
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high" temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
) )
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"load_in_8bit", "load_in_8bit",
[True, False], [True, False],
@@ -120,6 +119,10 @@ class TestKnowledgeDistillation:
"lora_r": 16, "lora_r": 16,
"lora_alpha": 32, "lora_alpha": 32,
"lora_dropout": 0.0, "lora_dropout": 0.0,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"lora_mlp_kernel": False,
"lora_qkv_kernel": False,
"lora_o_kernel": False,
} }
| kd_min_cfg | kd_min_cfg
) )