From 48434bec54cb44373cbeafa787f738c48f76cdba Mon Sep 17 00:00:00 2001 From: Tom Jobbins <784313+TheBloke@users.noreply.github.com> Date: Thu, 31 Aug 2023 22:26:52 +0100 Subject: [PATCH] Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511) --- scripts/finetune.py | 7 ++++++- src/axolotl/common/cli.py | 2 ++ src/axolotl/utils/tokenization.py | 14 ++++++++------ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 201a47e14..0a5f31863 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -246,9 +246,14 @@ def load_datasets( LOG.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( - [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec + [ + random.randrange(0, len(train_dataset) - 1) # nosec + for _ in range(cli_args.debug_num_examples) + ] ), tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, ) return TrainDatasetMeta( diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index f5bd9b037..62f2b1061 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -21,6 +21,8 @@ class TrainerCliArgs: """ debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=5) inference: bool = field(default=False) merge_lora: bool = field(default=False) prepare_ds_only: bool = field(default=False) diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index b2d1df400..82fcbc638 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -8,13 +8,13 @@ from termcolor import colored LOG = logging.getLogger("axolotl") -def check_dataset_labels(dataset, tokenizer): +def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False): # the dataset is already shuffled, so let's just check the first 5 elements - for idx in range(5): - check_example_labels(dataset[idx], tokenizer) + for idx in range(num_examples): + check_example_labels(dataset[idx], tokenizer, text_only=text_only) -def check_example_labels(example, tokenizer): +def check_example_labels(example, tokenizer, text_only=False): # Get the input_ids, labels, and attention_mask from the dataset input_ids = example["input_ids"] labels = example["labels"] @@ -29,8 +29,10 @@ def check_example_labels(example, tokenizer): decoded_input_token = tokenizer.decode(input_id) # Choose the color based on whether the label has the ignore value or not color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") - colored_token = colored(decoded_input_token, color) + colored( - f"({label_id}, {mask}, {input_id})", "white" + colored_token = colored(decoded_input_token, color) + ( + not text_only + and colored(f"({label_id}, {mask}, {input_id})", "white") + or "" ) colored_tokens.append(colored_token)