fix: Enable KD plugin support for PEFT/LoRA adapters (#3207)
- Fix _loss_function attribute not found on base model with PEFT - Fix mismatched attribute name (loss_function vs _loss_function) - Set _loss_function on unwrapped base model for PEFT - Enable previously skipped test_llama_lora_kd test - Add test config fixes for LoRA kernel compatibility Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
This commit is contained in:
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
|
|||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
||||||
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||||
|
|
||||||
loss = self.loss_function(
|
loss = self._loss_function(
|
||||||
self.lm_head.weight,
|
self.lm_head.weight,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.model_accepts_loss_kwargs = True
|
self.model_accepts_loss_kwargs = True
|
||||||
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
|
||||||
|
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
|
||||||
self.args.kd_ce_alpha, # hard label loss
|
self.args.kd_ce_alpha, # hard label loss
|
||||||
self.args.kd_alpha, # kd loss
|
self.args.kd_alpha, # kd loss
|
||||||
self.args.kd_temperature,
|
self.args.kd_temperature,
|
||||||
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||||
normalize_topk=self.args.kd_normalize_topk,
|
normalize_topk=self.args.kd_normalize_topk,
|
||||||
)
|
)
|
||||||
|
target = self.model
|
||||||
|
|
||||||
|
# Unwrap PEFT wrapper
|
||||||
|
if hasattr(target, "get_base_model"):
|
||||||
|
target = target.get_base_model()
|
||||||
|
|
||||||
|
# Set on the actual model instance
|
||||||
|
target._loss_function = loss_fn
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
super()._set_signature_columns_if_needed()
|
super()._set_signature_columns_if_needed()
|
||||||
|
|||||||
@@ -104,7 +104,6 @@ class TestKnowledgeDistillation:
|
|||||||
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"load_in_8bit",
|
"load_in_8bit",
|
||||||
[True, False],
|
[True, False],
|
||||||
@@ -120,6 +119,10 @@ class TestKnowledgeDistillation:
|
|||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
"lora_dropout": 0.0,
|
"lora_dropout": 0.0,
|
||||||
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
|
"lora_mlp_kernel": False,
|
||||||
|
"lora_qkv_kernel": False,
|
||||||
|
"lora_o_kernel": False,
|
||||||
}
|
}
|
||||||
| kd_min_cfg
|
| kd_min_cfg
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user