Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511)
This commit is contained in:
@@ -246,9 +246,14 @@ def load_datasets(
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
|
||||
@@ -21,6 +21,8 @@ class TrainerCliArgs:
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
|
||||
@@ -8,13 +8,13 @@ from termcolor import colored
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_dataset_labels(dataset, tokenizer):
|
||||
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
for idx in range(5):
|
||||
check_example_labels(dataset[idx], tokenizer)
|
||||
for idx in range(num_examples):
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
|
||||
|
||||
def check_example_labels(example, tokenizer):
|
||||
def check_example_labels(example, tokenizer, text_only=False):
|
||||
# Get the input_ids, labels, and attention_mask from the dataset
|
||||
input_ids = example["input_ids"]
|
||||
labels = example["labels"]
|
||||
@@ -29,8 +29,10 @@ def check_example_labels(example, tokenizer):
|
||||
decoded_input_token = tokenizer.decode(input_id)
|
||||
# Choose the color based on whether the label has the ignore value or not
|
||||
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
|
||||
colored_token = colored(decoded_input_token, color) + colored(
|
||||
f"({label_id}, {mask}, {input_id})", "white"
|
||||
colored_token = colored(decoded_input_token, color) + (
|
||||
not text_only
|
||||
and colored(f"({label_id}, {mask}, {input_id})", "white")
|
||||
or ""
|
||||
)
|
||||
colored_tokens.append(colored_token)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user