From b2774af66c64fe07e50d648c08b1446629f0da85 Mon Sep 17 00:00:00 2001 From: mashdragon <122402293+mashdragon@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:06:50 +0000 Subject: [PATCH] Take `split` param from config in all load_dataset instances (#2281) --- src/axolotl/utils/data/shared.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index d14496d96..e4f31a184 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token): except (FileNotFoundError, ConnectionError): pass + # gather extra args from the config + load_ds_kwargs = {} + if config_dataset.split: + load_ds_kwargs["split"] = config_dataset.split + else: + load_ds_kwargs["split"] = None + # prefer local dataset, even if hub exists local_path = Path(config_dataset.path) if local_path.exists(): @@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.data_files, streaming=False, - split=None, + **load_ds_kwargs, ) else: try: @@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token): config_dataset.path, name=config_dataset.name, streaming=False, - split=None, + **load_ds_kwargs, ) elif local_path.is_file(): ds_type = get_ds_type(config_dataset) @@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, + **load_ds_kwargs, ) else: raise ValueError( "unhandled dataset load: local path exists, but is neither a directory or a file" ) elif ds_from_hub: - load_ds_kwargs = {} - if config_dataset.split: - load_ds_kwargs["split"] = config_dataset.split ds = load_dataset( config_dataset.path, name=config_dataset.name, @@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, + **load_ds_kwargs, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, + **load_ds_kwargs, ) else: if isinstance(config_dataset.data_files, str): @@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=fp, streaming=False, - split=None, + **load_ds_kwargs, ) if not ds: raise ValueError("unhandled dataset load")