allow pretrain to be used with sft

This commit is contained in:
Wing Lian
2024-12-13 12:58:37 -05:00
parent effc4dc409
commit d000851eeb

View File

@@ -49,12 +49,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg): def load(tokenizer, cfg):
if cfg.pretraining_dataset:
cfg_ds = cfg.pretraining_dataset
else:
cfg_ds = cfg.dataset
strat = PretrainTokenizationStrategy( strat = PretrainTokenizationStrategy(
PretrainTokenizer(), PretrainTokenizer(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
text_column=cfg.pretraining_dataset[0]["text_column"] or "text", text_column=cfg_ds[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64, max_length=cfg.sequence_len * 64,
) )
return strat return strat