Log more info during preprocess for KD and fix import
This commit is contained in:
@@ -19,7 +19,8 @@ KD trainer
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
|
||||
@@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
# Get the input_ids, labels, and attention_mask from the dataset
|
||||
input_ids = example["input_ids"]
|
||||
labels = example["labels"]
|
||||
target_mask = example.pop("target_mask", None)
|
||||
|
||||
# You can compare the input_ids and labels element-wise
|
||||
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
||||
@@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
delimiter = "" if text_only else " "
|
||||
LOG.info(delimiter.join(colored_tokens))
|
||||
LOG.info("\n\n\n")
|
||||
target_labels_count = sum(label_id != -100 for label_id in labels)
|
||||
total_len = len(input_ids)
|
||||
LOG.info(f"Total input len: {total_len}")
|
||||
LOG.info(f"Count of labels: {target_labels_count}")
|
||||
if target_mask:
|
||||
target_mask_positions = sum(m[0] for m in target_mask)
|
||||
LOG.info(f"Number of positions in target_mask: {target_mask_positions}")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user