make sure to truncate logprobs if there are more than top_k

This commit is contained in:
Wing Lian
2025-01-21 10:26:27 -05:00
parent 32258c247e
commit bb5e6f4b72

View File

@@ -74,7 +74,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
# and truncate the second dimension of the logprobs to top_k
logprobs = logprobs[:-input_seq_len, :top_k]
input_padding_len = 0
# target_seq_len = input_seq_len