set labels and fix datasets block
This commit is contained in:
@@ -41,6 +41,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|||||||
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
|
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
|
||||||
]
|
]
|
||||||
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
|
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
|
||||||
|
res["labels"] = res["input_ids"].copy()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@@ -52,7 +53,7 @@ def load(tokenizer, cfg):
|
|||||||
if cfg.pretraining_dataset:
|
if cfg.pretraining_dataset:
|
||||||
cfg_ds = cfg.pretraining_dataset
|
cfg_ds = cfg.pretraining_dataset
|
||||||
else:
|
else:
|
||||||
cfg_ds = cfg.dataset
|
cfg_ds = cfg.datasets
|
||||||
strat = PretrainTokenizationStrategy(
|
strat = PretrainTokenizationStrategy(
|
||||||
PretrainTokenizer(),
|
PretrainTokenizer(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user