pretrain: fix with sample_packing=false (#1841)

This commit is contained in:
Aman Gupta Karmani
2024-08-21 10:36:51 -07:00
committed by GitHub
parent 5aac4bc284
commit 649c19aba3
2 changed files with 3 additions and 3 deletions

View File

@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples,
+        examples["text"],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,

View File

@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
         self.assertEqual(len(result["input_ids"]), 3)