set labels and fix datasets block

This commit is contained in:
Wing Lian
2024-12-13 13:04:24 -05:00
parent d000851eeb
commit 9eaae5925a

View File

@@ -41,6 +41,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
res["labels"] = res["input_ids"].copy()
return res
@@ -52,7 +53,7 @@ def load(tokenizer, cfg):
if cfg.pretraining_dataset:
cfg_ds = cfg.pretraining_dataset
else:
cfg_ds = cfg.dataset
cfg_ds = cfg.datasets
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,