Change the logic so that logprobs are always truncated to top_k
This commit is contained in:
@@ -74,11 +74,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
if input_padding_len < 0:
|
if input_padding_len < 0:
|
||||||
# logprobs is longer than target_seq_len,
|
# logprobs is longer than target_seq_len,
|
||||||
# so we need to slice from the left/beginning of logprobs
|
# so we need to slice from the left/beginning of logprobs
|
||||||
# and truncate the second dimension of the logprobs to top_k
|
logprobs = logprobs[:-input_seq_len]
|
||||||
logprobs = logprobs[:-input_seq_len, :top_k]
|
|
||||||
input_padding_len = 0
|
input_padding_len = 0
|
||||||
# target_seq_len = input_seq_len
|
# target_seq_len = input_seq_len
|
||||||
|
|
||||||
|
# truncate the second dimension of the logprobs to top_k
|
||||||
|
logprobs = logprobs[:, :top_k]
|
||||||
|
|
||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user