change up logic so we always truncate to top_k

This commit is contained in:
Wing Lian
2025-01-21 11:20:01 -05:00
parent bb5e6f4b72
commit bded6df509

View File

@@ -74,11 +74,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we slice it: keep the left/beginning and drop the last input_seq_len entries
# and truncate the second dimension of the logprobs to top_k
logprobs = logprobs[:-input_seq_len, :top_k]
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = logprobs[:, :top_k]
# pad with -inf: one row of top_k entries for each of the padding_len tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf