From 6e409d2d88c229d02c34fc82d23001570540c1d6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Dec 2024 15:58:02 -0500 Subject: [PATCH] more info on preprocess for kd and fix import --- src/axolotl/integrations/kd/trainer.py | 3 ++- src/axolotl/utils/tokenization.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index ad68055c9..9d686299e 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -19,7 +19,8 @@ KD trainer import torch from axolotl.core.trainers.base import AxolotlTrainer -from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss + +from .topk_logprob.forward_kl import loss as topk_kd_loss class AxolotlKDTrainer(AxolotlTrainer): diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 139d50110..e0b21a9f0 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False): # Get the input_ids, labels, and attention_mask from the dataset input_ids = example["input_ids"] labels = example["labels"] + target_mask = example.pop("target_mask", None) # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 @@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False): delimiter = "" if text_only else " " LOG.info(delimiter.join(colored_tokens)) LOG.info("\n\n\n") + target_labels_count = sum(label_id != -100 for label_id in labels) + total_len = len(input_ids) + LOG.info(f"Total input len: {total_len}") + LOG.info(f"Count of labels: {target_labels_count}") + if target_mask: + target_mask_positions = sum(m[0] for m in target_mask) + LOG.info(f"Number of positions in target_mask: {target_mask_positions}") return " ".join(colored_tokens)