Add debug option for RL dataset preprocessing (#1404)
* adding debug option for RL dataset preprocessing * Refine formatting of debugging code in RL dataset preprocessing * Update __init__.py * chore: fix lint --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -433,6 +433,23 @@ def load_rl_datasets(
|
|||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cli_args.debug or cfg.debug:
|
||||||
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
check_dataset_labels(
|
||||||
|
train_dataset.select(
|
||||||
|
[
|
||||||
|
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||||
|
for _ in range(cli_args.debug_num_examples)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
num_examples=cli_args.debug_num_examples,
|
||||||
|
text_only=cli_args.debug_text_only,
|
||||||
|
rl_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module for tokenization utilities"""
|
"""Module for tokenization utilities"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -10,10 +9,19 @@ from termcolor import colored
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
def check_dataset_labels(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
num_examples=5,
|
||||||
|
text_only=False,
|
||||||
|
rl_mode=False,
|
||||||
|
):
|
||||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||||
for idx in range(num_examples):
|
for idx in range(num_examples):
|
||||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
if not rl_mode:
|
||||||
|
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||||
|
else:
|
||||||
|
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||||
|
|
||||||
|
|
||||||
def check_example_labels(example, tokenizer, text_only=False):
|
def check_example_labels(example, tokenizer, text_only=False):
|
||||||
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
|
||||||
|
"""Helper function to color tokens based on their type."""
|
||||||
|
colored_text = colored(decoded_token, color)
|
||||||
|
return (
|
||||||
|
colored_text
|
||||||
|
if text_only
|
||||||
|
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
||||||
|
"""Helper function to process and color tokens."""
|
||||||
|
colored_tokens = [
|
||||||
|
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||||
|
for token in tokenizer.encode(tokens)
|
||||||
|
]
|
||||||
|
return colored_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||||
|
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
||||||
|
|
||||||
|
input_tokens = example[field_prompt]
|
||||||
|
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
||||||
|
|
||||||
|
# Process and color each type of token
|
||||||
|
colored_tokens = process_tokens_for_rl_debug(
|
||||||
|
input_tokens, "yellow", tokenizer, text_only
|
||||||
|
)
|
||||||
|
colored_chosens = process_tokens_for_rl_debug(
|
||||||
|
labels_chosen, "green", tokenizer, text_only
|
||||||
|
)
|
||||||
|
colored_rejecteds = process_tokens_for_rl_debug(
|
||||||
|
labels_rejected, "red", tokenizer, text_only
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a delimiter based on text_only flag
|
||||||
|
delimiter = "" if text_only else " "
|
||||||
|
|
||||||
|
# Logging information
|
||||||
|
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
||||||
|
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||||
|
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||||
|
|
||||||
|
return delimiter.join(colored_tokens)
|
||||||
|
|
||||||
|
|
||||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||||
"SYSTEM": "system",
|
"SYSTEM": "system",
|
||||||
|
|||||||
Reference in New Issue
Block a user