diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index d14496d96..e4f31a184 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token): except (FileNotFoundError, ConnectionError): pass + # gather extra args from the config + load_ds_kwargs = {} + if config_dataset.split: + load_ds_kwargs["split"] = config_dataset.split + else: + load_ds_kwargs["split"] = None + # prefer local dataset, even if hub exists local_path = Path(config_dataset.path) if local_path.exists(): @@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.data_files, streaming=False, - split=None, + **load_ds_kwargs, ) else: try: @@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token): config_dataset.path, name=config_dataset.name, streaming=False, - split=None, + **load_ds_kwargs, ) elif local_path.is_file(): ds_type = get_ds_type(config_dataset) @@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, + **load_ds_kwargs, ) else: raise ValueError( "unhandled dataset load: local path exists, but is neither a directory or a file" ) elif ds_from_hub: - load_ds_kwargs = {} - if config_dataset.split: - load_ds_kwargs["split"] = config_dataset.split ds = load_dataset( config_dataset.path, name=config_dataset.name, @@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, + **load_ds_kwargs, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=config_dataset.path, streaming=False, - split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, + **load_ds_kwargs, ) else: if isinstance(config_dataset.data_files, str): @@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token): name=config_dataset.name, data_files=fp, streaming=False, - split=None, + **load_ds_kwargs, ) if not ds: raise ValueError("unhandled dataset load")