Add debug option for RL dataset preprocessing (#1404)
* adding debug option for RL dataset preprocessing * Refine formatting of debugging code in RL dataset preprocessing * Update __init__.py * chore: fix lint --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -433,6 +433,23 @@ def load_rl_datasets(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
@@ -10,10 +9,19 @@ from termcolor import colored
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
||||
def check_dataset_labels(
|
||||
dataset,
|
||||
tokenizer,
|
||||
num_examples=5,
|
||||
text_only=False,
|
||||
rl_mode=False,
|
||||
):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
for idx in range(num_examples):
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
if not rl_mode:
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
else:
|
||||
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
|
||||
|
||||
def check_example_labels(example, tokenizer, text_only=False):
|
||||
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
|
||||
"""Helper function to color tokens based on their type."""
|
||||
colored_text = colored(decoded_token, color)
|
||||
return (
|
||||
colored_text
|
||||
if text_only
|
||||
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
|
||||
)
|
||||
|
||||
|
||||
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
||||
"""Helper function to process and color tokens."""
|
||||
colored_tokens = [
|
||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||
for token in tokenizer.encode(tokens)
|
||||
]
|
||||
return colored_tokens
|
||||
|
||||
|
||||
def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
||||
|
||||
input_tokens = example[field_prompt]
|
||||
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
||||
|
||||
# Process and color each type of token
|
||||
colored_tokens = process_tokens_for_rl_debug(
|
||||
input_tokens, "yellow", tokenizer, text_only
|
||||
)
|
||||
colored_chosens = process_tokens_for_rl_debug(
|
||||
labels_chosen, "green", tokenizer, text_only
|
||||
)
|
||||
colored_rejecteds = process_tokens_for_rl_debug(
|
||||
labels_rejected, "red", tokenizer, text_only
|
||||
)
|
||||
|
||||
# Create a delimiter based on text_only flag
|
||||
delimiter = "" if text_only else " "
|
||||
|
||||
# Logging information
|
||||
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
|
||||
return delimiter.join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
|
||||
Reference in New Issue
Block a user