From d000851eeb69229f7d481e158c542f64bf768363 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Dec 2024 12:58:37 -0500 Subject: [PATCH] allow pretrain to be used with sft --- src/axolotl/prompt_strategies/pretrain.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/pretrain.py b/src/axolotl/prompt_strategies/pretrain.py index 8430a7fca..ecf3c2d62 100644 --- a/src/axolotl/prompt_strategies/pretrain.py +++ b/src/axolotl/prompt_strategies/pretrain.py @@ -49,12 +49,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy): def load(tokenizer, cfg): + if cfg.pretraining_dataset: + cfg_ds = cfg.pretraining_dataset + else: + cfg_ds = cfg.dataset strat = PretrainTokenizationStrategy( PretrainTokenizer(), tokenizer, cfg.train_on_inputs, cfg.sequence_len, - text_column=cfg.pretraining_dataset[0]["text_column"] or "text", + text_column=cfg_ds[0]["text_column"] or "text", max_length=cfg.sequence_len * 64, ) return strat