Files
axolotl/tests/test_data.py
Wing Lian 05f03b541a hf offline decorator for tests to workaround rate limits (#2452) [skip ci]
* hf offline decorator for tests to work around rate limits

* fail quicker so we can see logs

* try new cache name

* limit files downloaded

* phi mini predownload

* offline decorator for phi tokenizer

* handle meta llama 8b offline too

* make sure to return fixtures if they are wrapped too

* more fixes

* more things offline

* more offline things

* fix the env var

* fix the model name

* handle gemma also

* force reload of modules to recheck offline status

* prefetch mistral too

* use reset_sessions so hub picks up offline mode

* more fixes

* rename so it doesn't seem like a context manager

* fix backoff

* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline

* include additional dataset

* more fixes

* more fixes

* replace tiny shakespeare dataset

* skip some tests for now

* use more robust check using snapshot download to determine if a dataset name is on the hub

* typo for skip reason

* use local_files_only

* more fixtures

* remove local only

* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached

* make sure fixtures aren't offline

improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache

* now run all the ci again with hopefully a correct cache

* skip wonky tests for now

* skip wonky tests for now

* handle offline mode for model card creation
2025-03-28 19:20:46 -04:00

68 lines
2.2 KiB
Python

"""
test module for the axolotl.utils.data module
"""
import unittest
from transformers import LlamaTokenizer
from utils import enable_hf_offline
from axolotl.utils.data import encode_pretraining, md5
class TestEncodePretraining(unittest.TestCase):
    """
    Test cases covering the ``encode_pretraining`` and ``md5`` helpers.
    """

    @enable_hf_offline
    def setUp(self):
        """Load the Llama tokenizer, register its special tokens, and set the token budget."""
        special_tokens = {
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
        }
        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
        tokenizer.add_special_tokens(special_tokens)
        self.tokenizer = tokenizer
        # Deliberately small so the packed rows are easy to inspect by eye.
        self.max_tokens = 15

    def test_encode_pretraining(self):
        """Check row count, row length, and special-token positions of the first row."""
        texts = [
            "Hello, world!",
            "Nice to meet you.",
            "lorem ipsum dolor sit amet.",
            "Nice to meet you again!.",
            "hello, hello",
        ]
        encoded = encode_pretraining(
            self.tokenizer, self.max_tokens, {"text": texts}
        )

        input_ids = encoded["input_ids"]
        self.assertEqual(len(input_ids), 3)

        # The first row of input_ids and attention_mask must fill the budget.
        first_row = input_ids[0]
        self.assertEqual(len(first_row), self.max_tokens)
        self.assertEqual(len(encoded["attention_mask"][0]), self.max_tokens)

        # Expected (position, token id) pairs for the special tokens in the
        # first row, checked in order.
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        expected_layout = (
            # "Hello, world!" is 4 tokens, sitting between BOS and EOS.
            (0, bos_id),
            (5, eos_id),
            (6, pad_id),
            # second part, 5 tokens
            (7, bos_id),
            (13, eos_id),
            (14, pad_id),
        )
        for position, token_id in expected_layout:
            self.assertEqual(first_row[position], token_id)

    def test_md5(self):
        """``md5`` gives the same digest with and without an explicit "utf-8" argument."""
        expected_digest = "5eb63bbbe01eeed093cb22bb8f5acdc3"
        self.assertEqual(md5("hello world"), expected_digest)
        self.assertEqual(md5("hello world", "utf-8"), expected_digest)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()