diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 2f17f0a9b..671f8e5c8 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -66,7 +66,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_seq_len = len(logprobs) input_seq_len = len(sample["input_ids"]) input_padding_len = input_seq_len - target_seq_len - top_k = len(logprobs[0]) + # get non-zero top-k + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + top_k = max(set(top_k_vals), key=top_k_vals.count) target_logprobs = [] target_token_ids = [] target_mask = []