more info on preprocess for kd and fix import

This commit is contained in:
Wing Lian
2024-12-30 15:58:02 -05:00
parent d5bc214300
commit 6e409d2d88
2 changed files with 10 additions and 1 deletion

View File

@@ -19,7 +19,8 @@ KD trainer
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import loss as topk_kd_loss
class AxolotlKDTrainer(AxolotlTrainer):

View File

@@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
target_mask = example.pop("target_mask", None)
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
target_labels_count = sum(label_id != -100 for label_id in labels)
total_len = len(input_ids)
LOG.info(f"Total input len: {total_len}")
LOG.info(f"Count of labels: {target_labels_count}")
if target_mask:
target_mask_positions = sum(m[0] for m in target_mask)
LOG.info(f"Number of positions in target_mask: {target_mask_positions}")
return " ".join(colored_tokens)