From 4e4a16cd8a58e0705f46f325a00532c8e434cc8e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 21 Jan 2025 13:09:20 -0500
Subject: [PATCH] fix finding the top-k rather than assuming first position has the correct val

---
 src/axolotl/integrations/kd/chat_template.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py
index 2f17f0a9b..671f8e5c8 100644
--- a/src/axolotl/integrations/kd/chat_template.py
+++ b/src/axolotl/integrations/kd/chat_template.py
@@ -66,7 +66,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         target_seq_len = len(logprobs)
         input_seq_len = len(sample["input_ids"])
         input_padding_len = input_seq_len - target_seq_len
-        top_k = len(logprobs[0])
+        # get non-zero top-k
+        top_k_vals = [
+            len(logprobs[i])
+            for i in range(len(logprobs))
+            if logprobs[i] is not None and len(logprobs[i])
+        ]
+        top_k = max(set(top_k_vals), key=top_k_vals.count)
         target_logprobs = []
         target_token_ids = []
         target_mask = []