From 63146300b7c890dcc12af630e88c8a246588c55f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 25 Dec 2024 18:59:30 -0500
Subject: [PATCH] fix kd loss so it's causal (fixes repeating tokens)

---
 src/axolotl/core/trainers/kd.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index e8adfab41..6473f26f0 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -164,6 +164,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
         target_token_ids = inputs.pop("target_token_ids")
         target_mask = inputs.pop("target_mask")
 
+        seq_len = target_token_ids.shape[1]
+
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
@@ -171,12 +173,19 @@
             inputs = {**inputs, **loss_kwargs}
 
         outputs = model(**inputs)
-        student_logits = outputs["logits"]
+        # FIXME: account for tokenizer.padding_side
+        student_logits = outputs["logits"][:, :seq_len, :].contiguous()
+
+        shift_logits = student_logits[..., :-1, :].contiguous()
+        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
+        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
+        shift_target_mask = target_mask[..., 1:, :].contiguous()
+
         loss_kd = kd_loss_function(
-            student_logits,
-            target_token_ids,
-            target_logprobs,
-            target_mask,
+            shift_logits,
+            shift_target_token_ids,
+            shift_target_logprobs,
+            shift_target_mask,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )