fix finding the top-k rather than assuming first position has the correct val

This commit is contained in:
Wing Lian
2025-01-21 13:09:20 -05:00
parent 67c1c8405e
commit 4e4a16cd8a

View File

@@ -66,7 +66,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
# get non-zero top-k
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
top_k = max(set(top_k_vals), key=top_k_vals.count)
target_logprobs = []
target_token_ids = []
target_mask = []