fix: Enable KD plugin support for PEFT/LoRA adapters (#3207)

- Fix `_loss_function` attribute not found on the base model when using PEFT
- Fix mismatched attribute name (`loss_function` vs `_loss_function`)
- Set `_loss_function` on the unwrapped base model for PEFT
- Enable the previously skipped `test_llama_lora_kd` test
- Add test config fixes for LoRA kernel compatibility

Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
This commit is contained in:
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
||||
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
loss = self.loss_function(
|
||||
loss = self._loss_function(
|
||||
self.lm_head.weight,
|
||||
hidden_states,
|
||||
target_token_ids,
|
||||
|
||||
@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
||||
|
||||
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
|
||||
self.args.kd_ce_alpha, # hard label loss
|
||||
self.args.kd_alpha, # kd loss
|
||||
self.args.kd_temperature,
|
||||
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||
normalize_topk=self.args.kd_normalize_topk,
|
||||
)
|
||||
target = self.model
|
||||
|
||||
# Unwrap PEFT wrapper
|
||||
if hasattr(target, "get_base_model"):
|
||||
target = target.get_base_model()
|
||||
|
||||
# Set on the actual model instance
|
||||
target._loss_function = loss_fn
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
|
||||
Reference in New Issue
Block a user