flipped the slice

Wing Lian
2024-12-19 01:21:48 -05:00
parent 283faf3909
commit 1c603da96a


@@ -23,7 +23,7 @@ def kd_loss_function(
     # Slice student logits to match the teacher-provided sequence length
     student_logits_for_kd = student_logits[
-        :, -teacher_seq_len:, :
+        :, :teacher_seq_len, :
     ]  # [B, teacher_seq_len, vocab_size]
     # Gather student logits for teacher's top-K tokens
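
To illustrate what flipping the slice changes, here is a minimal sketch using plain Python lists in place of tensors. The lengths (6 and 4) and the helper names `slice_first` and `slice_last` are hypothetical and chosen only for illustration. Which slice lines up with the teacher's logits depends on how the teacher sequences are padded or truncated, which this commit does not state.

```python
# Hypothetical toy lengths; the real tensors are [B, seq_len, vocab_size].
student_seq_len = 6
teacher_seq_len = 4

# Stand-in for the sequence axis of one batch row: each entry is the
# position index of a logit vector along that axis.
student_positions = list(range(student_seq_len))


def slice_first(seq, n):
    """Keep the first n positions, analogous to student_logits[:, :n, :]."""
    return seq[:n]


def slice_last(seq, n):
    """Keep the last n positions, analogous to student_logits[:, -n:, :]."""
    return seq[-n:]


first = slice_first(student_positions, teacher_seq_len)
last = slice_last(student_positions, teacher_seq_len)

print("first n positions:", first)
print("last n positions: ", last)
```

Both slices have length `teacher_seq_len`, so either one passes a shape check. They differ only in which student positions get compared against the teacher, so a wrong choice would not raise an error; it would silently misalign the two sequences.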