From 5eae1341101b8551f3051cdc8a5329c022b116ca Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 10 Jan 2025 04:04:13 +0700 Subject: [PATCH] feat: add support for data_files in pretraining (#2238) --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/data/sft.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 77597ae1a..1bca08396 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel): text_column: Optional[str] = "text" type: Optional[str] = "pretrain" trust_remote_code: Optional[bool] = False + data_files: Optional[str] = None class UserDefinedPrompterType(BaseModel): diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 3e784ca3e..cfc40406e 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -88,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): path = cfg.pretraining_dataset split = "train" name = None + data_files = None if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): @@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None): if "split" in cfg.pretraining_dataset[0]: split = cfg.pretraining_dataset[0]["split"] + data_files = cfg.pretraining_dataset[0].get("data_files") + ds_wrapper_partial = functools.partial( get_dataset_wrapper, cfg.pretraining_dataset[0], @@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): ) train_dataset = wrap_pretraining_dataset( - load_dataset(path, streaming=True, split=split, name=name), + load_dataset( + path, streaming=True, split=split, name=name, data_files=data_files + ), tokenizer, cfg, ds_wrapper_partial,