def load(tokenizer, cfg):
    """Build the pretraining tokenization strategy from the run config.

    The text column is taken from the first entry of the configured dataset
    list. ``cfg.pretraining_dataset`` is preferred; when it is unset or empty
    the regular dataset list is used instead.

    Args:
        tokenizer: Tokenizer instance, passed through unchanged to
            ``PretrainTokenizationStrategy``.
        cfg: Run configuration. The attributes read here are
            ``pretraining_dataset``, ``dataset`` / ``datasets``,
            ``train_on_inputs`` and ``sequence_len``.

    Returns:
        A ``PretrainTokenizationStrategy`` whose ``max_length`` is
        ``cfg.sequence_len * 64``.

    Raises:
        ValueError: If none of the dataset lists is configured (or all are
            empty), so no ``text_column`` can be looked up.
    """
    # Lookup order keeps the previous behaviour first (`pretraining_dataset`,
    # then `dataset`) and only then tries `datasets`.
    # NOTE(review): `datasets` is the key Axolotl documents for the regular
    # dataset list; `dataset` is retained purely for backward compatibility
    # with the earlier lookup -- confirm whether it can be dropped.
    cfg_ds = (
        cfg.pretraining_dataset
        or getattr(cfg, "dataset", None)
        or getattr(cfg, "datasets", None)
    )
    if not cfg_ds:
        # Fail with an actionable message instead of an opaque
        # "'NoneType' object is not subscriptable" / IndexError on cfg_ds[0].
        raise ValueError(
            "pretrain strategy requires `pretraining_dataset` or `datasets` "
            "to be set in the config"
        )

    # Only the first dataset entry is consulted; a falsy `text_column`
    # (e.g. None) falls back to the default column name "text".
    text_column = cfg_ds[0]["text_column"] or "text"

    return PretrainTokenizationStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        text_column=text_column,
        max_length=cfg.sequence_len * 64,
    )