* Initial CLI implementation with click package * Adding fetch command for pulling examples and deepspeed configs * Automating default options for CliArgs classes * Mimicking existing no config behavior * bugfix in choose_config * Updating fetch to sync instead of re-download * bugfix * isort fix * fixing yaml isort order * pre-commit fixes * simplifying argument parsing -- pass through kwargs to do_cli * make accelerate launch default for non-preprocess commands * fixing arg handling * testing None placeholder approach * removing hacky --use-gpu argument to preprocess command * Adding brief README documentation for CLI * remove (New) * Initial CLI pytest tests * progress on CLI pytest * adding inference CLI tests; cleanup * Refactor train CLI tests to remove various mocking * Major CLI test refator; adding remaining CLI codepath test coverage * pytest fixes * remove integration markers * parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming * moving cli pytest due to isolation issues; cleanup * testing fixes; various minor improvements * fix * tests fix * Update tests/cli/conftest.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
"""
|
|
test module for the axolotl.utils.data module
|
|
"""
import unittest

from transformers import LlamaTokenizer

from axolotl.utils.data import encode_pretraining, md5


class TestEncodePretraining(unittest.TestCase):
    """
    Test class for encode_pretraining and the md5 helper.
    """

    @classmethod
    def setUpClass(cls):
        # Loading the tokenizer is expensive (it downloads/reads a full
        # sentencepiece vocab), so do it once for the whole class instead of
        # once per test method as the old per-test setUp did. Nothing here is
        # mutated by the tests, so sharing it is safe.
        cls.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
        cls.tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "<s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
            }
        )
        cls.max_tokens = 15  # set a small number for easy inspection

    def test_encode_pretraining(self):
        """Short texts are packed into max_tokens-long sequences with
        BOS/EOS markers and PAD filling the remainder."""
        examples = {
            "text": [
                "Hello, world!",
                "Nice to meet you.",
                "lorem ipsum dolor sit amet.",
                "Nice to meet you again!.",
                "hello, hello",
            ]
        }
        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)

        # Five short texts pack into three sequences at max_tokens=15.
        self.assertEqual(len(result["input_ids"]), 3)

        # Assert the length of input_ids and attention_mask is correct
        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)

        # Assert BOS, EOS and PAD tokens are correctly added
        # "Hello, world!" is 4 tokens
        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
        # second packed text, 5 tokens
        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)

    def test_md5(self):
        """md5 returns the hex digest; the encoding argument defaults to utf-8,
        so passing it explicitly must not change the result."""
        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
        self.assertEqual(
            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
        )
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|