From 649c19aba31c022028bb508c1b945da9fe407e94 Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Wed, 21 Aug 2024 10:36:51 -0700
Subject: [PATCH] pretrain: fix with sample_packing=false (#1841)

---
 src/axolotl/utils/data/pretraining.py | 4 ++--
 tests/test_data.py                    | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index e056c7f50..16f38218c 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
 
 
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples,
+        examples["text"],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
diff --git a/tests/test_data.py b/tests/test_data.py
index 16af089a0..9d7f5a041 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
         self.assertEqual(len(result["input_ids"]), 3)