feat: add support for data_files in pretraining (#2238)

This commit is contained in:
NanoCode012
2025-01-10 04:04:13 +07:00
committed by GitHub
parent 7669a03fb4
commit ed77e7001e
2 changed files with 7 additions and 1 deletion

View File

@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text" text_column: Optional[str] = "text"
type: Optional[str] = "pretrain" type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
class UserDefinedPrompterType(BaseModel): class UserDefinedPrompterType(BaseModel):

View File

@@ -88,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
split = "train" split = "train"
name = None name = None
data_files = None
if isinstance(cfg.pretraining_dataset, list) and isinstance( if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict cfg.pretraining_dataset[0], dict
): ):
@@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
if "split" in cfg.pretraining_dataset[0]: if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"] split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
@@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
) )
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name), load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,