workaround for md5 variations (#533)
* workaround for md5 variations * refactor the prepared hash too
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from hashlib import md5
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
|
|||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
|
|
||||||
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
|
try:
|
||||||
|
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||||
|
except TypeError:
|
||||||
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
) -> DatasetDict:
|
) -> DatasetDict:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5( # nosec
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
)
|
||||||
).hexdigest()
|
)
|
||||||
)
|
)
|
||||||
prepared_ds_path = (
|
prepared_ds_path = (
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
@@ -374,7 +380,7 @@ def load_prepare_datasets(
|
|||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5( # nosec
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
@@ -385,8 +391,8 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
)
|
||||||
).hexdigest()
|
)
|
||||||
)
|
)
|
||||||
prepared_ds_path = (
|
prepared_ds_path = (
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
@@ -500,12 +506,8 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ str(cfg.seed or 42)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
train_fingerprint = hashlib.md5(
|
train_fingerprint = md5(to_hash_train)
|
||||||
to_hash_train.encode(), usedforsecurity=False
|
test_fingerprint = md5(to_hash_test)
|
||||||
).hexdigest()
|
|
||||||
test_fingerprint = hashlib.md5(
|
|
||||||
to_hash_test.encode(), usedforsecurity=False
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
|
|||||||
64
tests/test_data.py
Normal file
64
tests/test_data.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
test module for the axolotl.utis.data module
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.data import encode_pretraining, md5
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncodePretraining(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
test class for encode pretraining and md5 helper
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
self.tokenizer.add_special_tokens(
|
||||||
|
{
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.max_tokens = 15 # set a small number for easy inspection
|
||||||
|
|
||||||
|
def test_encode_pretraining(self):
|
||||||
|
examples = {
|
||||||
|
"text": [
|
||||||
|
"Hello, world!",
|
||||||
|
"Nice to meet you.",
|
||||||
|
"lorem ipsum dolor sit amet.",
|
||||||
|
"Nice to meet you again!.",
|
||||||
|
"hello, hello",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
||||||
|
|
||||||
|
self.assertEqual(len(result["input_ids"]), 3)
|
||||||
|
|
||||||
|
# Assert the length of input_ids and attention_mask is correct
|
||||||
|
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
|
||||||
|
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
|
||||||
|
|
||||||
|
# Assert EOS and PAD tokens are correctly added
|
||||||
|
# hello world! is 4 tokens
|
||||||
|
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
|
||||||
|
# second part, 5 tokens
|
||||||
|
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def test_md5(self):
|
||||||
|
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
|
||||||
|
self.assertEqual(
|
||||||
|
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user