feat: add support for data_files in pretraining (#2238)
This commit is contained in:
@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
|
||||
text_column: Optional[str] = "text"
|
||||
type: Optional[str] = "pretrain"
|
||||
trust_remote_code: Optional[bool] = False
|
||||
data_files: Optional[str] = None
|
||||
|
||||
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
|
||||
@@ -88,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
path = cfg.pretraining_dataset
|
||||
split = "train"
|
||||
name = None
|
||||
data_files = None
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
@@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
if "split" in cfg.pretraining_dataset[0]:
|
||||
split = cfg.pretraining_dataset[0]["split"]
|
||||
|
||||
data_files = cfg.pretraining_dataset[0].get("data_files")
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
@@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
)
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split=split, name=name),
|
||||
load_dataset(
|
||||
path, streaming=True, split=split, name=name, data_files=data_files
|
||||
),
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
|
||||
Reference in New Issue
Block a user