Flipped the student-logits slice along the sequence dimension (`-teacher_seq_len:` ↔ `:teacher_seq_len`)
This commit is contained in:
@@ -23,7 +23,7 @@ def kd_loss_function(
|
||||
|
||||
# Slice student logits to match the teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, -teacher_seq_len:, :
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
|
||||
Reference in New Issue
Block a user