fix sharegpt tokenization, refactor tokenization debugging

Wing Lian
2023-04-30 00:23:53 -04:00
parent c0f50d9c61
commit 5159d00a86
5 changed files with 63 additions and 41 deletions


@@ -11,6 +11,8 @@ import yaml
 from attrdict import AttrDefault
 # add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -42,36 +44,6 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}
 
-def check_dataset_labels(dataset, tokenizer):
-    from termcolor import colored
-
-    # the dataset is already shuffled, so let's just check the first 5 elements
-    for idx in range(5):
-        # Get the input_ids, labels, and attention_mask from the dataset
-        input_ids = dataset[idx]["input_ids"]
-        labels = dataset[idx]["labels"]
-        attention_mask = dataset[idx]["attention_mask"]
-
-        # You can compare the input_ids and labels element-wise
-        # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
-        colored_tokens = []
-        for i, (input_id, label_id, mask) in enumerate(
-            zip(input_ids, labels, attention_mask)
-        ):
-            decoded_input_token = tokenizer.decode(input_id)
-            # Choose the color based on whether the label has the ignore value or not
-            color = (
-                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-            )
-            colored_token = colored(decoded_input_token, color) + colored(
-                f"({label_id}, {mask})", "white"
-            )
-            colored_tokens.append(colored_token)
-
-        logging.info(" ".join(colored_tokens))
-        logging.info("\n\n\n")
-
 def do_inference(cfg, model, tokenizer):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -199,8 +171,9 @@ def train(
         return
 
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
+            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
             tokenizer,
         )
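With cfg.debug enabled, the trainer now logs five randomly drawn training examples rather than a single one, matching the five-example loop inside check_dataset_labels. Note that random.randrange can return the same index more than once and, because of the "- 1", can never pick the final example; a duplicate-free variant (a sketch only, not part of this commit) could use random.sample instead:

import random

# sketch only -- draw 5 distinct indices, with the last example eligible as well
debug_indices = random.sample(range(len(train_dataset)), k=min(5, len(train_dataset)))
check_dataset_labels(train_dataset.select(debug_indices), tokenizer)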