From bded6df509d2e87a0da81d5226039c91e2d10560 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 21 Jan 2025 11:20:01 -0500 Subject: [PATCH] change up logic so we always truncate to top_k --- src/axolotl/integrations/kd/chat_template.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index bf2de93d6..5a0c4e90b 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -74,11 +74,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): if input_padding_len < 0: # logprobs is longer than target_seq_len, # so we need to slice from the left/beginning of logprobs - # and truncate the second dimension of the logprobs to top_k - logprobs = logprobs[:-input_seq_len, :top_k] + logprobs = logprobs[:-input_seq_len] input_padding_len = 0 # target_seq_len = input_seq_len + # truncate the second dimension of the logprobs to top_k + logprobs = logprobs[:, :top_k] + # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf