pretrain: fix with sample_packing=false (#1841)
This commit is contained in:
committed by
GitHub
parent
5aac4bc284
commit
649c19aba3
@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
res = tokenizer(
|
res = tokenizer(
|
||||||
examples,
|
examples["text"],
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_tokens - 2,
|
max_length=max_tokens - 2,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
"hello, hello",
|
"hello, hello",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
|
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
||||||
|
|
||||||
self.assertEqual(len(result["input_ids"]), 3)
|
self.assertEqual(len(result["input_ids"]), 3)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user