diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index e8adfab41..6473f26f0 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -164,6 +164,8 @@ class AxolotlKDTrainer(AxolotlTrainer): target_token_ids = inputs.pop("target_token_ids") target_mask = inputs.pop("target_mask") + seq_len = target_token_ids.shape[1] + if self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: @@ -171,12 +173,19 @@ class AxolotlKDTrainer(AxolotlTrainer): inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) - student_logits = outputs["logits"] + # FIXME: account for tokenizer.padding_side + student_logits = outputs["logits"][:, :seq_len, :].contiguous() + + shift_logits = student_logits[..., :-1, :].contiguous() + shift_target_logprobs = target_logprobs[..., 1:, :].contiguous() + shift_target_token_ids = target_token_ids[..., 1:, :].contiguous() + shift_target_mask = target_mask[..., 1:, :].contiguous() + loss_kd = kd_loss_function( - student_logits, - target_token_ids, - target_logprobs, - target_mask, + shift_logits, + shift_target_token_ids, + shift_target_logprobs, + shift_target_mask, num_items_in_batch=num_items_in_batch, kd_temperature=self.args.kd_temperature, )