fix sharegpt tokenization, refactor tokenization debugging
This commit is contained in:
@@ -11,6 +11,8 @@ import yaml
|
||||
from attrdict import AttrDefault
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
@@ -42,36 +44,6 @@ def choose_device(cfg):
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
def check_dataset_labels(dataset, tokenizer):
|
||||
from termcolor import colored
|
||||
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
for idx in range(5):
|
||||
# Get the input_ids, labels, and attention_mask from the dataset
|
||||
input_ids = dataset[idx]["input_ids"]
|
||||
labels = dataset[idx]["labels"]
|
||||
attention_mask = dataset[idx]["attention_mask"]
|
||||
|
||||
# You can compare the input_ids and labels element-wise
|
||||
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
||||
colored_tokens = []
|
||||
for i, (input_id, label_id, mask) in enumerate(
|
||||
zip(input_ids, labels, attention_mask)
|
||||
):
|
||||
decoded_input_token = tokenizer.decode(input_id)
|
||||
# Choose the color based on whether the label has the ignore value or not
|
||||
color = (
|
||||
"red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
|
||||
)
|
||||
colored_token = colored(decoded_input_token, color) + colored(
|
||||
f"({label_id}, {mask})", "white"
|
||||
)
|
||||
colored_tokens.append(colored_token)
|
||||
|
||||
logging.info(" ".join(colored_tokens))
|
||||
logging.info("\n\n\n")
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer):
|
||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||
@@ -199,8 +171,9 @@ def train(
|
||||
return
|
||||
|
||||
if cfg.debug:
|
||||
logging.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
|
||||
train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user