diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 20d0fcfb8..f322b800b 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -2,7 +2,6 @@
 import functools
 import hashlib
 import logging
-from hashlib import md5
 from pathlib import Path
 from typing import Tuple, Union
 
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
+def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    try:
+        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
+    except TypeError:
+        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
+
+
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
 ) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
-        md5(  # nosec
+        md5(
             (
                 str(cfg.sequence_len)
                 + "@"
@@ -97,8 +103,8 @@
                 )
                 + "|"
                 + tokenizer_name
-            ).encode("utf-8")
-        ).hexdigest()
+            )
+        )
     )
     prepared_ds_path = (
         Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
         ds_hash = str(
-            md5(  # nosec
+            md5(
                 (
                     str(cfg.sequence_len)
                     + "@"
                     + f"{cfg.max_packed_sequence_len}"
                     + seed
@@ -385,8 +391,8 @@
                     )
                     + "|"
                     + tokenizer_name
-                ).encode("utf-8")
-            ).hexdigest()
+                )
+            )
         )
         prepared_ds_path = (
             Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
             + "|"
             + str(cfg.seed or 42)
         )
-        train_fingerprint = hashlib.md5(
-            to_hash_train.encode(), usedforsecurity=False
-        ).hexdigest()
-        test_fingerprint = hashlib.md5(
-            to_hash_test.encode(), usedforsecurity=False
-        ).hexdigest()
+        train_fingerprint = md5(to_hash_train)
+        test_fingerprint = md5(to_hash_test)
 
         with zero_first(is_main_process()):
             dataset = dataset.train_test_split(
diff --git a/tests/test_data.py b/tests/test_data.py
new file mode 100644
index 000000000..9d7f5a041
--- /dev/null
+++ b/tests/test_data.py
@@ -0,0 +1,64 @@
+"""
+test module for the axolotl.utils.data module
+"""
+import unittest
+
+from transformers import LlamaTokenizer
+
+from axolotl.utils.data import encode_pretraining, md5
+
+
+class TestEncodePretraining(unittest.TestCase):
+    """
+    test class for encode pretraining and md5 helper
+    """
+
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "eos_token": "</s>",
+                "bos_token": "<s>",
+                "unk_token": "<unk>",
+                "pad_token": "<pad>",
+            }
+        )
+        self.max_tokens = 15  # set a small number for easy inspection
+
+    def test_encode_pretraining(self):
+        examples = {
+            "text": [
+                "Hello, world!",
+                "Nice to meet you.",
+                "lorem ipsum dolor sit amet.",
+                "Nice to meet you again!.",
+                "hello, hello",
+            ]
+        }
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+
+        self.assertEqual(len(result["input_ids"]), 3)
+
+        # Assert the length of input_ids and attention_mask is correct
+        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
+        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
+
+        # Assert EOS and PAD tokens are correctly added
+        # hello world! is 4 tokens
+        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
+        # second part, 5 tokens
+        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
+
+    def test_md5(self):
+        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
+        self.assertEqual(
+            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
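
Reviewer note: hashlib.md5() only accepts the usedforsecurity keyword on Python 3.9+, which is why the new helper tries that form first and falls back to the plain call (carrying the # nosec marker for bandit) on older interpreters. A minimal sanity check of the helper, assuming axolotl is importable in the current environment (illustrative snippet, not part of the patch):

    from axolotl.utils.data import md5

    # same digest the new test_md5 case asserts
    assert md5("hello world") == "5eb63bbbe01eeed093cb22bb8f5acdc3"
    assert md5("hello world", "utf-8") == "5eb63bbbe01eeed093cb22bb8f5acdc3"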