Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511)

This commit is contained in:
Tom Jobbins
2023-08-31 22:26:52 +01:00
committed by GitHub
parent 396a7a74fc
commit 48434bec54
3 changed files with 16 additions and 7 deletions

View File

@@ -246,9 +246,14 @@ def load_datasets(
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec [
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
), ),
tokenizer, tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
) )
return TrainDatasetMeta( return TrainDatasetMeta(

View File

@@ -21,6 +21,8 @@ class TrainerCliArgs:
""" """
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False) inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False) prepare_ds_only: bool = field(default=False)

View File

@@ -8,13 +8,13 @@ from termcolor import colored
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer): def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
# the dataset is already shuffled, so let's just check the first 5 elements # the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(5): for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer) check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer): def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset # Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"] input_ids = example["input_ids"]
labels = example["labels"] labels = example["labels"]
@@ -29,8 +29,10 @@ def check_example_labels(example, tokenizer):
decoded_input_token = tokenizer.decode(input_id) decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not # Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + colored( colored_token = colored(decoded_input_token, color) + (
f"({label_id}, {mask}, {input_id})", "white" not text_only
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)