fix finding the top-k rather than assuming first position has the correct val
This commit is contained in:
@@ -66,7 +66,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
top_k = len(logprobs[0])
|
||||
# get non-zero top-k
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
Reference in New Issue
Block a user