pretrain: fix with sample_packing=false (#1841)

This commit is contained in:
Aman Gupta Karmani
2024-08-21 10:36:51 -07:00
committed by GitHub
parent 5aac4bc284
commit 649c19aba3
2 changed files with 3 additions and 3 deletions

View File

@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
- tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+ tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
res = tokenizer(
- examples,
+ examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,

View File

@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
- result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+ result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)