change up logic so we always truncate to top_k

This commit is contained in:
Wing Lian
2025-01-21 11:20:01 -05:00
parent bb5e6f4b72
commit bded6df509

View File

@@ -74,11 +74,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
 if input_padding_len < 0:
     # logprobs is longer than target_seq_len,
     # so we need to slice from the left/beginning of logprobs
-    # and truncate the second dimension of the logprobs to top_k
-    logprobs = logprobs[:-input_seq_len, :top_k]
+    logprobs = logprobs[:-input_seq_len]
     input_padding_len = 0
     # target_seq_len = input_seq_len
+# truncate the second dimension of the logprobs to top_k
+logprobs = logprobs[:, :top_k]
 # fill with -inf for padding_len tokens for top_k tokens
 # extend target_logprobs with a padding_len x top_k 2D list filled with -inf