allow pretrain to be used with sft
This commit is contained in:
@@ -49,12 +49,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
|
if cfg.pretraining_dataset:
|
||||||
|
cfg_ds = cfg.pretraining_dataset
|
||||||
|
else:
|
||||||
|
cfg_ds = cfg.dataset
|
||||||
strat = PretrainTokenizationStrategy(
|
strat = PretrainTokenizationStrategy(
|
||||||
PretrainTokenizer(),
|
PretrainTokenizer(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
|
text_column=cfg_ds[0]["text_column"] or "text",
|
||||||
max_length=cfg.sequence_len * 64,
|
max_length=cfg.sequence_len * 64,
|
||||||
)
|
)
|
||||||
return strat
|
return strat
|
||||||
|
|||||||
Reference in New Issue
Block a user