flipped the slice
This commit is contained in:
@@ -23,7 +23,7 @@ def kd_loss_function(
|
|||||||
|
|
||||||
# Slice student logits to match the teacher-provided sequence length
|
# Slice student logits to match the teacher-provided sequence length
|
||||||
student_logits_for_kd = student_logits[
|
student_logits_for_kd = student_logits[
|
||||||
:, -teacher_seq_len:, :
|
:, :teacher_seq_len, :
|
||||||
] # [B, teacher_seq_len, vocab_size]
|
] # [B, teacher_seq_len, vocab_size]
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
# Gather student logits for teacher's top-K tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user