From 1eebbd09c39556c6ead0ed3c2e3a234082911d5a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 19 Sep 2023 08:09:56 -0400 Subject: [PATCH] improve handling for empty text on the tokenization step (#502) --- src/axolotl/prompt_tokenizers.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index b1aaeb350..f30d0e383 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -6,7 +6,7 @@ import functools import logging from typing import Dict, List, Tuple, Union -from transformers import PreTrainedTokenizer +from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import IGNORE_TOKEN_ID @@ -66,14 +66,21 @@ class PromptTokenizingStrategy(abc.ABC): pass return False - def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) + def _tokenize( + self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False + ) -> BatchEncoding: + result: BatchEncoding + if not prompt.strip(): + LOG.warning("Empty text requested for tokenization.") + result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) + else: + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.sequence_len, + padding=False, + return_tensors=None, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") if (