From 5159d00a86ef7c358aa819d0bafadd1d5d8304e8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 30 Apr 2023 00:23:53 -0400
Subject: [PATCH] fix sharegpt tokenization, refactor tokenization debugging

---
 scripts/finetune.py               | 35 ++++---------------------------
 src/axolotl/prompters.py          | 21 ++++++++++++++-----
 src/axolotl/utils/models.py       | 10 ++++-----
 src/axolotl/utils/tokenization.py | 33 +++++++++++++++++++++++++++++
 src/axolotl/utils/trainer.py      |  5 +++++
 5 files changed, 63 insertions(+), 41 deletions(-)
 create mode 100644 src/axolotl/utils/tokenization.py

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 858f33f9a..cd8c6f650 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -11,6 +11,8 @@ import yaml
 from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -42,36 +44,6 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}
 
 
-def check_dataset_labels(dataset, tokenizer):
-    from termcolor import colored
-
-    # the dataset is already shuffled, so let's just check the first 5 elements
-    for idx in range(5):
-        # Get the input_ids, labels, and attention_mask from the dataset
-        input_ids = dataset[idx]["input_ids"]
-        labels = dataset[idx]["labels"]
-        attention_mask = dataset[idx]["attention_mask"]
-
-        # You can compare the input_ids and labels element-wise
-        # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
-        colored_tokens = []
-        for i, (input_id, label_id, mask) in enumerate(
-            zip(input_ids, labels, attention_mask)
-        ):
-            decoded_input_token = tokenizer.decode(input_id)
-            # Choose the color based on whether the label has the ignore value or not
-            color = (
-                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-            )
-            colored_token = colored(decoded_input_token, color) + colored(
-                f"({label_id}, {mask})", "white"
-            )
-            colored_tokens.append(colored_token)
-
-        logging.info(" ".join(colored_tokens))
-        logging.info("\n\n\n")
-
-
 def do_inference(cfg, model, tokenizer):
     tokenizer.add_special_tokens({"unk_token": ""})
     tokenizer.add_special_tokens({"bos_token": ""})
@@ -199,8 +171,9 @@ def train(
         return
 
     if cfg.debug:
+        logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
+            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
             tokenizer,
         )
 
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 903ee4385..c2acf60c3 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -127,7 +127,7 @@ conv_vicuna_v1_1 = Conversation(
 
 
 class ShareGPTPrompter:
-    def build_prompt(self, source, tokenizer):
+    def build_prompt(self, source, tokenizer, sequence_len=2048):
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
             source.pop(0)
@@ -157,13 +157,14 @@ class ShareGPTPrompter:
             role = roles[sentence["from"]]
             assert role == conv.roles[j % 2]
             conv.append_message(role, sentence["value"])
+        # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
         conversation = conv.get_prompt()
 
         # Tokenize conversations
         tokenized_result = tokenizer(
             conversation,
             truncation=True,
-            max_length=2048,  # FIXME
+            max_length=sequence_len,  # FIXME
             padding=False,
             return_tensors=None,
         )
@@ -173,7 +174,9 @@ class ShareGPTPrompter:
         sep = conv.sep + conv.roles[1] + ": "
 
         rounds = conversation.split(conv.sep2)
+        rounds = [r + conv.sep2 for r in rounds]
         cur_len = 1
+        target[0] = IGNORE_TOKEN_ID  # mask out the bos
         for i, rou in enumerate(rounds):
             if rou == "":
                 break
@@ -182,19 +185,27 @@ class ShareGPTPrompter:
             if len(parts) != 2:
                 break
             parts[0] += sep
-            round_len = len(tokenizer(rou)["input_ids"])
-            instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
+            round_len = len(tokenizer(rou)["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+            # we have to strip the initial part, any dangling whitespace creates an additional ghost token
+            instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1  # -1 ignores the bos_token generated for this
             target[cur_len : cur_len + instruction_len] = [
                 IGNORE_TOKEN_ID
             ] * instruction_len
 
             cur_len += round_len
-        target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+            if cur_len >= sequence_len:
+                break
+
+        # Fix: Truncate the target to have the same length as input_ids
+        target = target[:len(tokenized_result["input_ids"])]
+        # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+
         attention_mask = [
             1 if x != tokenizer.pad_token_id else 0
             for x in tokenized_result["input_ids"]
         ]
 
+        # TODO truncate len to sequence_len
         return dict(
             input_ids=tokenized_result["input_ids"],
             labels=target,
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a14f89b19..ce85a47eb 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -53,7 +53,7 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -161,11 +161,11 @@ def load_model(
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            setattr(tokenizer, k, v)
+    if cfg.tokens:
+        for k, v in cfg.tokens.items():
+            tokenizer.add_special_tokens({k: v})
 
-    if load_in_8bit and not cfg.load_4bit:
+    if load_in_8bit and cfg.load_4bit:
         logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
new file mode 100644
index 000000000..b9ffb1e1b
--- /dev/null
+++ b/src/axolotl/utils/tokenization.py
@@ -0,0 +1,33 @@
+from termcolor import colored
+import logging
+
+def check_dataset_labels(dataset, tokenizer):
+    # the dataset is already shuffled, so let's just check the first 5 elements
+    for idx in range(5):
+        check_example_labels(dataset[idx], tokenizer)
+
+
+def check_example_labels(example, tokenizer):
+    # Get the input_ids, labels, and attention_mask from the dataset
+    input_ids = example["input_ids"]
+    labels = example["labels"]
+    attention_mask = example["attention_mask"]
+
+    # You can compare the input_ids and labels element-wise
+    # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
+    colored_tokens = []
+    for i, (input_id, label_id, mask) in enumerate(
+        zip(input_ids, labels, attention_mask)
+    ):
+        decoded_input_token = tokenizer.decode(input_id)
+        # Choose the color based on whether the label has the ignore value or not
+        color = (
+            "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
+        )
+        colored_token = colored(decoded_input_token, color) + colored(
+            f"({label_id}, {mask}, {input_id})", "white"
+        )
+        colored_tokens.append(colored_token)
+
+    logging.info(" ".join(colored_tokens))
+    logging.info("\n\n\n")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8ce05ba12..99ea101f5 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -61,6 +61,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
+        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
+        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
+        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
+        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )