workaround for md5 variations (#533)

* workaround for md5 variations

* refactor the prepared hash too
This commit is contained in:
Wing Lian
2023-09-08 16:01:05 -04:00
committed by GitHub
parent 78ee2cdab2
commit 0b4cf5bc8c
2 changed files with 79 additions and 13 deletions

View File

@@ -2,7 +2,6 @@
import functools import functools
import hashlib import hashlib
import logging import logging
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Tuple, Union
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str:
    """Return the hexadecimal MD5 digest of ``to_hash``.

    The digest is requested with ``usedforsecurity=False`` first. If the
    running ``hashlib.md5`` rejects that keyword with a ``TypeError``, the
    plain call is used instead, so the same digest is produced either way.

    Args:
        to_hash: Text to fingerprint.
        encoding: Codec used to turn ``to_hash`` into bytes before hashing.

    Returns:
        The 32-character lowercase hex digest.
    """
    # Encode once, outside the try block, so an encoding failure is reported
    # directly instead of being mistaken for an unsupported keyword argument.
    payload = to_hash.encode(encoding)
    try:
        digest = hashlib.md5(payload, usedforsecurity=False)
    except TypeError:
        # This hashlib.md5 does not accept the ``usedforsecurity`` keyword.
        digest = hashlib.md5(payload)  # nosec
    return digest.hexdigest()
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict: ) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( # nosec md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") )
).hexdigest() )
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str( ds_hash = str(
md5( # nosec md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
@@ -385,8 +391,8 @@ def load_prepare_datasets(
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") )
).hexdigest() )
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
+ "|" + "|"
+ str(cfg.seed or 42) + str(cfg.seed or 42)
) )
train_fingerprint = hashlib.md5( train_fingerprint = md5(to_hash_train)
to_hash_train.encode(), usedforsecurity=False test_fingerprint = md5(to_hash_test)
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with zero_first(is_main_process()): with zero_first(is_main_process()):
dataset = dataset.train_test_split( dataset = dataset.train_test_split(

64
tests/test_data.py Normal file
View File

@@ -0,0 +1,64 @@
"""
test module for the axolotl.utils.data module
"""
import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
class TestEncodePretraining(unittest.TestCase):
    """
    checks encode_pretraining and the md5 helper imported from axolotl.utils.data
    """

    def setUp(self):
        special_tokens = {
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
        }
        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
        tokenizer.add_special_tokens(special_tokens)
        self.tokenizer = tokenizer
        # a small limit keeps the encoded rows easy to inspect
        self.max_tokens = 15

    def test_encode_pretraining(self):
        texts = [
            "Hello, world!",
            "Nice to meet you.",
            "lorem ipsum dolor sit amet.",
            "Nice to meet you again!.",
            "hello, hello",
        ]
        encoded = encode_pretraining(
            self.tokenizer, self.max_tokens, {"text": texts}
        )

        self.assertEqual(len(encoded["input_ids"]), 3)

        # the first row of ids and of the attention mask must both be
        # exactly max_tokens long
        first_ids = encoded["input_ids"][0]
        self.assertEqual(len(first_ids), self.max_tokens)
        first_mask = encoded["attention_mask"][0]
        self.assertEqual(len(first_mask), self.max_tokens)

        # expected special tokens in the first row:
        #   first text  (4 tokens): BOS at 0, EOS at 5, PAD at 6
        #   second text (5 tokens): BOS at 7, EOS at 13, PAD at 14
        tok = self.tokenizer
        expected_at = (
            (0, tok.bos_token_id),
            (5, tok.eos_token_id),
            (6, tok.pad_token_id),
            (7, tok.bos_token_id),
            (13, tok.eos_token_id),
            (14, tok.pad_token_id),
        )
        for position, token_id in expected_at:
            self.assertEqual(first_ids[position], token_id)

    def test_md5(self):
        expected_digest = "5eb63bbbe01eeed093cb22bb8f5acdc3"
        # the default encoding and an explicit "utf-8" must give the same digest
        self.assertEqual(md5("hello world"), expected_digest)
        self.assertEqual(md5("hello world", "utf-8"), expected_digest)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()