From 510cf453179ca702520752fc0bf40b6ac1032cc6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Jan 2025 20:10:05 -0500 Subject: [PATCH] improve logprob masking and shift in trainer --- src/axolotl/integrations/kd/chat_template.py | 32 +++++++++++++++++--- src/axolotl/integrations/kd/trainer.py | 2 +- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index b03e059b9..32f5d0ce4 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): ) def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + logprobs = sample.pop(self.logprobs_field) target_seq_len = len(logprobs) input_seq_len = len(sample["input_ids"]) @@ -62,9 +66,20 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_token_ids = [] target_mask = [] + if input_padding_len < 0: + # logprobs is longer than target_seq_len, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[:-input_seq_len] + input_padding_len = 0 + target_seq_len = input_seq_len + # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - for _ in range(1, input_padding_len): # start at 1 since this is causal + + # for causal models, if we start the range at 1, then we don't need to shift in the trainer + # otherwise, we need to shift in the trainer + shift = 0 + for _ in range(shift, input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) @@ -73,6 +88,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # TODO also check against sample["labels"] target_mask.append([1] * top_k) + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + for _, token_pos_logprobs in enumerate(logprobs): # Initialize collections for logprobs and token_ids position_logprobs = [] @@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) - # since we started at index 1 for causal, we need one more padding token - target_logprobs.append([-float("inf")] * top_k) - target_token_ids.append(list(range(top_k))) - target_mask.append([0] * top_k) + if shift == 1: + # since we started at index 1 for causal, we need one more padding token + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 7eda30659..1aa1df452 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer): inputs, return_outputs=False, num_items_in_batch=None, - shift_targets=False, + shift_targets=True, ): """ How the loss is computed by Trainer. By default, all models return the loss in the first element.