diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
index f7b468669..badb3460d 100644
--- a/src/axolotl/integrations/kd/kernels/models.py
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
 
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
-    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
+    # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
 
-    loss = self.loss_function(
+    loss = self._loss_function(
         self.lm_head.weight,
         hidden_states,
         target_token_ids,
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index 7ec43333a..0e98497a7 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_accepts_loss_kwargs = True
-        self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
+
+        loss_fn = LigerFusedLinearKLTopKLogprobLoss(
             self.args.kd_ce_alpha,  # hard label loss
             self.args.kd_alpha,  # kd loss
             self.args.kd_temperature,
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
             compute_ce_loss=bool(self.args.kd_ce_alpha),
             normalize_topk=self.args.kd_normalize_topk,
         )
+        target = self.model
+
+        # Unwrap PEFT wrapper
+        if hasattr(target, "get_base_model"):
+            target = target.get_base_model()
+
+        # Set on the actual model instance
+        target._loss_function = loss_fn
 
     def _set_signature_columns_if_needed(self):
         super()._set_signature_columns_if_needed()
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index ff47b9427..d89044247 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -104,7 +104,6 @@ class TestKnowledgeDistillation:
             temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
         )
 
-    @pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
     @pytest.mark.parametrize(
         "load_in_8bit",
         [True, False],
@@ -120,6 +119,10 @@ class TestKnowledgeDistillation:
                 "lora_r": 16,
                 "lora_alpha": 32,
                 "lora_dropout": 0.0,
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+                "lora_mlp_kernel": False,
+                "lora_qkv_kernel": False,
+                "lora_o_kernel": False,
             }
             | kd_min_cfg
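
Note on the trainer change: peft's `PeftModel` proxies attribute *reads* through to the wrapped model via `__getattr__`, but plain attribute *writes* land on the wrapper itself. Since the patched forward in `models.py` looks up `self._loss_function` on the base model instance, assigning the loss onto the PEFT wrapper would leave the base model without it. (The leading underscore in `_loss_function` also, presumably, avoids colliding with the `loss_function` attribute transformers defines on `PreTrainedModel`.) Below is a minimal, self-contained sketch of the pitfall, using illustrative `Wrapper`/`Base` stand-ins rather than the real peft/transformers classes:

```python
class Base:
    """Stand-in for the unwrapped transformers model."""

    def forward(self):
        # The patched forward resolves the loss on *its own* instance
        # (`self` is the base model here, never the wrapper).
        return self._loss_function()


class Wrapper:
    """Stand-in for a PEFT wrapper: reads fall through, writes do not."""

    def __init__(self, base):
        self.base = base

    def get_base_model(self):
        return self.base

    def __getattr__(self, name):
        # Only invoked for attributes missing on the wrapper, so reads
        # like `wrapper.forward` delegate to the base model. Writes use
        # the default __setattr__ and stay on the wrapper.
        return getattr(self.base, name)


model = Wrapper(Base())

# Naive assignment: the attribute is stored on the wrapper, so the base
# model's forward() can't see it.
model._loss_function = lambda: "kd loss"
try:
    model.forward()  # bound to Base, which has no _loss_function
except AttributeError:
    pass

# The trainer's fix: unwrap first, then assign on the instance forward() uses.
target = model.get_base_model() if hasattr(model, "get_base_model") else model
target._loss_function = lambda: "kd loss"
assert model.forward() == "kd loss"
```

The `hasattr(target, "get_base_model")` guard keeps the full-finetune path untouched, since non-PEFT models don't expose that method. The `lora_modules_to_save: ["embed_tokens", "lm_head"]` addition in the test config likely follows from the fused kernel reading `self.lm_head.weight` directly, which wants a plain (non-adapter-wrapped) `lm_head`.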