From bb5e6f4b72759e3019268e329b95112259242cad Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 21 Jan 2025 10:26:27 -0500 Subject: [PATCH] make sure to truncate logprobs if there are more than top_k --- src/axolotl/integrations/kd/chat_template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index cc70e12d2..bf2de93d6 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -74,7 +74,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): if input_padding_len < 0: # logprobs is longer than target_seq_len, # so we need to slice from the left/beginning of logprobs - logprobs = logprobs[:-input_seq_len] + # and truncate the second dimension of the logprobs to top_k + logprobs = logprobs[:-input_seq_len, :top_k] input_padding_len = 0 # target_seq_len = input_seq_len