make sure to truncate logprobs if there are more than top_k
This commit is contained in:
@@ -74,7 +74,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
if input_padding_len < 0:
|
||||
# logprobs is longer than target_seq_len,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
# and truncate the second dimension of the logprobs to top_k
|
||||
logprobs = logprobs[:-input_seq_len, :top_k]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user