diff --git a/deepspeed_configs/zero1_torch_compile.json b/deepspeed_configs/zero1_torch_compile.json
new file mode 100644
index 000000000..b88451392
--- /dev/null
+++ b/deepspeed_configs/zero1_torch_compile.json
@@ -0,0 +1,27 @@
+{
+  "zero_optimization": {
+    "stage": 1,
+    "overlap_comm": true
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "compile": {
+    "disable": false,
+    "backend": "inductor"
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
diff --git a/requirements.txt b/requirements.txt
index 61e1a9f90..283b5cc2d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,4 +61,4 @@
 antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.1b2
+axolotl-contribs-lgpl==0.0.2
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 6f883a2ac..14803e43b 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
     "--accelerate/--no-accelerate",
-    default=True,
+    default=False,
     help="Use accelerate launch for multi-GPU inference",
 )
 @click.option(
@@ -124,7 +124,7 @@ def inference(
     if lora_model_dir:
         kwargs["lora_model_dir"] = lora_model_dir
     if base_model:
-        kwargs["output_dir"] = base_model
+        kwargs["base_model"] = base_model

     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 54ee19536..e81740399 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -56,6 +56,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
+    GCCallback,
     GPUStatsCallback,
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
@@ -67,7 +68,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
-from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -1452,6 +1453,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))
+        if self.cfg.gc_steps:
+            callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

         callbacks.append(SaveModelCallback())

         return callbacks
@@ -1831,8 +1834,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
         if self.cfg.chat_template:
-            training_arguments_kwargs["chat_template"] = get_chat_template(
-                self.cfg.chat_template,
+            training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
+                cfg=self.cfg,
                 tokenizer=self.tokenizer,
             )
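
A note on the new `compile` block in `zero1_torch_compile.json`: recent DeepSpeed releases use this section to run `torch.compile` over the engine's wrapped module, so the config above should behave roughly like compiling the model yourself. A minimal sketch of that rough equivalence, assuming standard `torch.compile` semantics (the toy model is illustrative only, not part of this PR):

```python
import torch
import torch.nn as nn

# Approximate effect of `"compile": {"disable": false, "backend": "inductor"}`:
# DeepSpeed compiles the wrapped module with the chosen torch.compile backend.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(2, 16))  # first call triggers compilation
```
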
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index dc7289b09..a74ecc2ec 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,5 +1,6 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

+import inspect
 import os
 import signal
 import sys
@@ -126,7 +127,20 @@ def train(
     )
     if cfg.fix_untrained_tokens:
-        fix_untrained_tokens(model, tokenizer, train_dataset)
+        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+        sig = inspect.signature(fix_untrained_tokens)
+        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+        if "token_ids_to_fix" in sig.parameters and isinstance(
+            cfg.fix_untrained_tokens, list
+        ):
+            fix_untrained_tokens(
+                model,
+                tokenizer,
+                train_dataset,
+                token_ids_to_fix=cfg.fix_untrained_tokens,
+            )
+        else:
+            fix_untrained_tokens(model, tokenizer, train_dataset)

     if cfg.local_rank == 0:
         model.save_pretrained(
             str(Path(cfg.output_dir)), safe_serialization=safe_serialization
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 641c9b162..f1b459b6b 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import gc
 import logging
 import math
 import os
@@ -842,3 +843,17 @@
         ):
             control.should_save = True
         return control
+
+
+class GCCallback(TrainerCallback):
+    """Callback to garbage collect torch cache"""
+
+    def __init__(self, gc_steps=None):
+        self.gc_steps = gc_steps
+
+    def on_step_end(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        if state.global_step % self.gc_steps == 0:
+            torch.cuda.empty_cache()
+            gc.collect()
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 864f639c4..89209c66f 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -666,6 +666,8 @@ class AxolotlInputConfig(
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None

+    gc_steps: Optional[int] = None
+
     bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
@@ -792,7 +794,7 @@ class AxolotlInputConfig(
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None

-    fix_untrained_tokens: Optional[bool] = None
+    fix_untrained_tokens: Optional[Union[int, List[int]]] = None

     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None
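
The `train.py` hunk above probes the installed `fix_untrained_tokens` before passing the new `token_ids_to_fix` kwarg, which keeps older releases of `axolotl-contribs-lgpl` working alongside the pin bump to 0.0.2 in requirements.txt. A standalone sketch of the same pattern; `call_with_optional_kwarg` is a hypothetical helper, not something this PR adds:

```python
import inspect


def call_with_optional_kwarg(fn, *args, name, value):
    """Pass `name=value` to `fn` only if its signature accepts that kwarg.

    Mirrors the backward-compatibility guard in train.py above.
    """
    if name in inspect.signature(fn).parameters:
        return fn(*args, **{name: value})
    return fn(*args)


def old_impl(a):  # older contrib version: no kwarg
    return a


def new_impl(a, token_ids_to_fix=None):  # newer contrib version
    return (a, token_ids_to_fix)


assert call_with_optional_kwarg(old_impl, 1, name="token_ids_to_fix", value=[5]) == 1
assert call_with_optional_kwarg(new_impl, 1, name="token_ids_to_fix", value=[5]) == (1, [5])
```
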
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index 16f38218c..f493db70e 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -28,8 +28,10 @@
     )
     # Convert to PyTorch tensors
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
     new_input_ids = []
+    new_labels = []
     new_attention_mask = []
     # Append EOS and PAD tokens to input_ids, and correct attention_mask
     for i, _ in enumerate(input_ids):
@@ -40,22 +42,34 @@
             ),
             dim=0,
         )
+        targets[i] = torch.cat(
+            (
+                targets[i],
+                torch.tensor([tokenizer.eos_token_id, -100]),
+            ),
+            dim=0,
+        )
         attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)

     # Concatenate tokens so that their lengths are less than max_tokens
     buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_labels = torch.tensor([], dtype=torch.long)
     buffer_attention_mask = torch.tensor([], dtype=torch.long)

-    for ids, mask in zip(input_ids, attention_mask):
+    for ids, labels, mask in zip(input_ids, targets, attention_mask):
         if buffer_input_ids.numel() == max_tokens:
             new_input_ids.append(buffer_input_ids)
+            new_labels.append(buffer_labels)
             new_attention_mask.append(buffer_attention_mask)
             buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_labels = torch.tensor([], dtype=torch.long)
             buffer_attention_mask = torch.tensor([], dtype=torch.long)
             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
         elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
         else:
             buffer_input_ids = torch.cat(
@@ -69,6 +83,17 @@
                 ),
                 dim=0,
             )
+            buffer_labels = torch.cat(
+                (
+                    buffer_labels,
+                    torch.full(
+                        (max_tokens - buffer_labels.numel(),),
+                        -100,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
             buffer_attention_mask = torch.cat(
                 (
                     buffer_attention_mask,
@@ -81,11 +106,14 @@
                 dim=0,
             )
             new_input_ids.append(buffer_input_ids)
+            new_labels.append(buffer_labels)
             new_attention_mask.append(buffer_attention_mask)
             buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_labels = torch.tensor([], dtype=torch.long)
             buffer_attention_mask = torch.tensor([], dtype=torch.long)

             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

     if buffer_input_ids.numel() > 0:  # for any leftover tokens
@@ -101,6 +129,17 @@
             ),
             dim=0,
         )
+        buffer_labels = torch.cat(
+            (
+                buffer_labels,
+                torch.full(
+                    (max_tokens - buffer_labels.numel(),),
+                    -100,
+                    dtype=torch.long,
+                ),
+            ),
+            dim=0,
+        )
         buffer_attention_mask = torch.cat(
             (
                 buffer_attention_mask,
@@ -113,11 +152,12 @@
             dim=0,
         )
         new_input_ids.append(buffer_input_ids)
+        new_labels.append(buffer_labels)
         new_attention_mask.append(buffer_attention_mask)

     ret = {
         "input_ids": [seq.tolist() for seq in new_input_ids],
-        "labels": [seq.tolist() for seq in new_input_ids],
+        "labels": [seq.tolist() for seq in new_labels],
         "attention_mask": [seq.tolist() for seq in new_attention_mask],
     }
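
Why the `encode_pretraining` hunks fill the new label buffers with `-100` rather than reusing the input ids: `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so the appended PAD position and any fill at the end of a packed buffer contribute nothing to the loss. A small illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32)                   # 5 positions, vocab size 32
labels = torch.tensor([3, 7, 1, -100, -100])  # last two positions are padding
# cross_entropy defaults to ignore_index=-100, so only the first three
# positions contribute to the loss -- exactly what the new buffers rely on.
loss = F.cross_entropy(logits, labels)
```
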
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 286e5f2d7..3e784ca3e 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -3,7 +3,7 @@
 import functools
 import logging
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 from datasets import (
     Dataset,
@@ -12,8 +12,6 @@ from datasets import (
     load_dataset,
     load_from_disk,
 )
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import HFValidationError
 from transformers import PreTrainedTokenizerBase

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@ from axolotl.prompters import (
     UnsupportedPrompter,
 )
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+from axolotl.utils.data.shared import load_dataset_w_config
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     md5,
@@ -85,6 +84,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
             processor=processor,
         )
     else:
+        # Load streaming dataset if pretraining_dataset is given
         path = cfg.pretraining_dataset
         split = "train"
         name = None
@@ -116,7 +116,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
+
+        # Load eval dataset (non-streaming) if specified
         eval_dataset = None
+        if cfg.test_datasets:
+            _, eval_dataset, _ = load_prepare_datasets(
+                tokenizer,
+                cfg,
+                DEFAULT_DATASET_PREPARED_PATH,
+                split="test",
+                processor=processor,
+            )
+
         if cfg.dataset_exact_deduplication:
             LOG.info("Deduplication not available for pretrained datasets")
@@ -243,195 +254,9 @@ def load_tokenized_prepared_datasets(
     # pylint: disable=invalid-name
     for config_dataset in for_d_in_datasets(cfg_datasets):
-        ds: Optional[Union[Dataset, DatasetDict]] = None
-        ds_from_hub = False
-        ds_trust_remote_code = config_dataset.trust_remote_code
-        try:
-            # this is just a basic check to see if the path is a
-            # valid HF dataset that's loadable
-            load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
-                streaming=True,
-                token=use_auth_token,
-                revision=config_dataset.revision,
-                trust_remote_code=ds_trust_remote_code,
-            )
-            ds_from_hub = True
-        except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
-            pass
-
-        ds_from_cloud = False
-        storage_options = {}
-        remote_file_system = None
-        if config_dataset.path.startswith("s3://"):
-            try:
-                import aiobotocore.session  # type: ignore
-                import s3fs  # type: ignore
-            except ImportError as exc:
-                raise ImportError(
-                    "s3:// paths require aiobotocore and s3fs to be installed"
-                ) from exc
-
-            # Takes credentials from ~/.aws/credentials for default profile
-            s3_session = aiobotocore.session.AioSession(profile="default")
-            storage_options = {"session": s3_session}
-            remote_file_system = s3fs.S3FileSystem(**storage_options)
-        elif config_dataset.path.startswith(
-            "gs://"
-        ) or config_dataset.path.startswith("gcs://"):
-            try:
-                import gcsfs  # type: ignore
-            except ImportError as exc:
-                raise ImportError(
-                    "gs:// or gcs:// paths require gcsfs to be installed"
-                ) from exc
-
-            # gcsfs will use default credentials from the environment else anon
-            # https://gcsfs.readthedocs.io/en/latest/#credentials
-            storage_options = {"token": None}
-            remote_file_system = gcsfs.GCSFileSystem(**storage_options)
-        # TODO: Figure out how to get auth creds passed
-        # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
-        #     try:
-        #         import adlfs
-        #     except ImportError as exc:
-        #         raise ImportError(
-        #             "adl:// or abfs:// paths require adlfs to be installed"
-        #         ) from exc
-
-        #     # Gen 1
-        #     storage_options = {
-        #         "tenant_id": TENANT_ID,
-        #         "client_id": CLIENT_ID,
-        #         "client_secret": CLIENT_SECRET,
-        #     }
-        #     # Gen 2
-        #     storage_options = {
-        #         "account_name": ACCOUNT_NAME,
-        #         "account_key": ACCOUNT_KEY,
-        #     }
-
-        #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
-        try:
-            if remote_file_system and remote_file_system.exists(
-                config_dataset.path
-            ):
-                ds_from_cloud = True
-        except (FileNotFoundError, ConnectionError):
-            pass
-
-        # prefer local dataset, even if hub exists
-        local_path = Path(config_dataset.path)
-        if local_path.exists():
-            if local_path.is_dir():
-                if config_dataset.data_files:
-                    ds_type = get_ds_type(config_dataset)
-                    ds = load_dataset(
-                        ds_type,
-                        name=config_dataset.name,
-                        data_files=config_dataset.data_files,
-                        streaming=False,
-                        split=None,
-                    )
-                else:
-                    try:
-                        ds = load_from_disk(config_dataset.path)
-                    except FileNotFoundError:
-                        ds = load_dataset(
-                            config_dataset.path,
-                            name=config_dataset.name,
-                            streaming=False,
-                            split=None,
-                        )
-            elif local_path.is_file():
-                ds_type = get_ds_type(config_dataset)
-
-                ds = load_dataset(
-                    ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
-                    streaming=False,
-                    split=None,
-                )
-            else:
-                raise ValueError(
-                    "unhandled dataset load: local path exists, but is neither a directory or a file"
-                )
-        elif ds_from_hub:
-            load_ds_kwargs = {}
-            if config_dataset.split:
-                load_ds_kwargs["split"] = config_dataset.split
-            ds = load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
-                streaming=False,
-                data_files=config_dataset.data_files,
-                token=use_auth_token,
-                revision=config_dataset.revision,
-                trust_remote_code=config_dataset.trust_remote_code,
-                **load_ds_kwargs,
-            )
-        elif ds_from_cloud and remote_file_system:
-            if remote_file_system.isdir(config_dataset.path):
-                ds = load_from_disk(
-                    config_dataset.path,
-                    storage_options=storage_options,
-                )
-            elif remote_file_system.isfile(config_dataset.path):
-                ds_type = get_ds_type(config_dataset)
-                ds = load_dataset(
-                    ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
-                    streaming=False,
-                    split=None,
-                    storage_options=storage_options,
-                    trust_remote_code=config_dataset.trust_remote_code,
-                )
-        elif config_dataset.path.startswith("https://"):
-            ds_type = get_ds_type(config_dataset)
-            ds = load_dataset(
-                ds_type,
-                name=config_dataset.name,
-                data_files=config_dataset.path,
-                streaming=False,
-                split=None,
-                storage_options=storage_options,
-                trust_remote_code=config_dataset.trust_remote_code,
-            )
-        else:
-            if isinstance(config_dataset.data_files, str):
-                fp = hf_hub_download(
-                    repo_id=config_dataset.path,
-                    repo_type="dataset",
-                    filename=config_dataset.data_files,
-                    revision=config_dataset.revision,
-                )
-            elif isinstance(config_dataset.data_files, list):
-                fp = []
-                for file in config_dataset.data_files:
-                    fp.append(
-                        hf_hub_download(
-                            repo_id=config_dataset.path,
-                            repo_type="dataset",
-                            filename=file,
-                            revision=config_dataset.revision,
-                        )
-                    )
-            else:
-                raise ValueError(
-                    "data_files must be either a string or list of strings"
-                )
-            ds = load_dataset(
-                "json",
-                name=config_dataset.name,
-                data_files=fp,
-                streaming=False,
-                split=None,
-            )
-        if not ds:
-            raise ValueError("unhandled dataset load")
+        ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
+            config_dataset, use_auth_token
+        )

         d_base_type = d_prompt_style = None
         d_type = config_dataset.type
@@ -501,24 +326,6 @@
     return dataset, prompters


-def get_ds_type(config_dataset: DictDefault):
-    """
-    Get the dataset type from the path if it's not specified
-    """
-    ds_type = "json"
-    if config_dataset.ds_type:
-        ds_type = config_dataset.ds_type
-    elif ".parquet" in config_dataset.path:
-        ds_type = "parquet"
-    elif ".arrow" in config_dataset.path:
-        ds_type = "arrow"
-    elif ".csv" in config_dataset.path:
-        ds_type = "csv"
-    elif ".txt" in config_dataset.path:
-        ds_type = "text"
-    return ds_type
-
-
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
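
With the inline loading logic removed, `load_tokenized_prepared_datasets` now delegates to the `load_dataset_w_config` helper added in the new file below. A hypothetical usage sketch, with an arbitrary public dataset path chosen purely for illustration:

```python
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.dict import DictDefault

# One dataset entry from an axolotl config; the helper resolves local paths,
# the HF Hub, s3:// and gs:// buckets, and https:// URLs behind a single call.
config_dataset = DictDefault({"path": "tatsu-lab/alpaca"})
ds = load_dataset_w_config(config_dataset, auth_token=None)
```
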
diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py
new file mode 100644
index 000000000..d14496d96
--- /dev/null
+++ b/src/axolotl/utils/data/shared.py
@@ -0,0 +1,222 @@
+"""
+dataset loading shared utils
+"""
+from pathlib import Path
+from typing import Optional, Union
+
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+from huggingface_hub import hf_hub_download
+from huggingface_hub.errors import HFValidationError
+
+from axolotl.utils.dict import DictDefault
+
+
+def get_ds_type(config_dataset: DictDefault):
+    """
+    Get the dataset type from the path if it's not specified
+    """
+    ds_type = "json"
+    if config_dataset.ds_type:
+        ds_type = config_dataset.ds_type
+    elif ".parquet" in config_dataset.path:
+        ds_type = "parquet"
+    elif ".arrow" in config_dataset.path:
+        ds_type = "arrow"
+    elif ".csv" in config_dataset.path:
+        ds_type = "csv"
+    elif ".txt" in config_dataset.path:
+        ds_type = "text"
+    return ds_type
+
+
+def load_dataset_w_config(config_dataset, auth_token):
+    # pylint: disable=invalid-name
+    ds: Optional[Union[Dataset, DatasetDict]] = None  # pylint: disable=invalid-name
+    ds_from_hub = False
+    ds_trust_remote_code = config_dataset.trust_remote_code
+    try:
+        # this is just a basic check to see if the path is a
+        # valid HF dataset that's loadable
+        load_dataset(
+            config_dataset.path,
+            name=config_dataset.name,
+            streaming=True,
+            token=auth_token,
+            revision=config_dataset.revision,
+            trust_remote_code=ds_trust_remote_code,
+        )
+        ds_from_hub = True
+    except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
+        pass
+
+    ds_from_cloud = False
+    storage_options = {}
+    remote_file_system = None
+    if config_dataset.path.startswith("s3://"):
+        try:
+            import aiobotocore.session  # type: ignore
+            import s3fs  # type: ignore
+        except ImportError as exc:
+            raise ImportError(
+                "s3:// paths require aiobotocore and s3fs to be installed"
+            ) from exc
+
+        # Takes credentials from ~/.aws/credentials for default profile
+        s3_session = aiobotocore.session.AioSession(profile="default")
+        storage_options = {"session": s3_session}
+        remote_file_system = s3fs.S3FileSystem(**storage_options)
+    elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
+        "gcs://"
+    ):
+        try:
+            import gcsfs  # type: ignore
+        except ImportError as exc:
+            raise ImportError(
+                "gs:// or gcs:// paths require gcsfs to be installed"
+            ) from exc
+
+        # gcsfs will use default credentials from the environment else anon
+        # https://gcsfs.readthedocs.io/en/latest/#credentials
+        storage_options = {"token": None}
+        remote_file_system = gcsfs.GCSFileSystem(**storage_options)
+    # TODO: Figure out how to get auth creds passed
+    # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
+    #     try:
+    #         import adlfs
+    #     except ImportError as exc:
+    #         raise ImportError(
+    #             "adl:// or abfs:// paths require adlfs to be installed"
+    #         ) from exc
+
+    #     # Gen 1
+    #     storage_options = {
+    #         "tenant_id": TENANT_ID,
+    #         "client_id": CLIENT_ID,
+    #         "client_secret": CLIENT_SECRET,
+    #     }
+    #     # Gen 2
+    #     storage_options = {
+    #         "account_name": ACCOUNT_NAME,
+    #         "account_key": ACCOUNT_KEY,
+    #     }
+
+    #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
+    try:
+        if remote_file_system and remote_file_system.exists(config_dataset.path):
+            ds_from_cloud = True
+    except (FileNotFoundError, ConnectionError):
+        pass
+
+    # prefer local dataset, even if hub exists
+    local_path = Path(config_dataset.path)
+    if local_path.exists():
+        if local_path.is_dir():
+            if config_dataset.data_files:
+                ds_type = get_ds_type(config_dataset)
+                ds = load_dataset(  # pylint: disable=invalid-name
+                    ds_type,
+                    name=config_dataset.name,
+                    data_files=config_dataset.data_files,
+                    streaming=False,
+                    split=None,
+                )
+            else:
+                try:
+                    ds = load_from_disk(
+                        config_dataset.path
+                    )  # pylint: disable=invalid-name
+                except FileNotFoundError:
+                    ds = load_dataset(
+                        config_dataset.path,
+                        name=config_dataset.name,
+                        streaming=False,
+                        split=None,
+                    )
+        elif local_path.is_file():
+            ds_type = get_ds_type(config_dataset)
+
+            ds = load_dataset(  # pylint: disable=invalid-name
+                ds_type,
+                name=config_dataset.name,
+                data_files=config_dataset.path,
+                streaming=False,
+                split=None,
+            )
+        else:
+            raise ValueError(
+                "unhandled dataset load: local path exists, but is neither a directory or a file"
+            )
+    elif ds_from_hub:
+        load_ds_kwargs = {}
+        if config_dataset.split:
+            load_ds_kwargs["split"] = config_dataset.split
+        ds = load_dataset(
+            config_dataset.path,
+            name=config_dataset.name,
+            streaming=False,
+            data_files=config_dataset.data_files,
+            token=auth_token,
+            revision=config_dataset.revision,
+            trust_remote_code=config_dataset.trust_remote_code,
+            **load_ds_kwargs,
+        )
+    elif ds_from_cloud and remote_file_system:
+        if remote_file_system.isdir(config_dataset.path):
+            ds = load_from_disk(
+                config_dataset.path,
+                storage_options=storage_options,
+            )
+        elif remote_file_system.isfile(config_dataset.path):
+            ds_type = get_ds_type(config_dataset)
+            ds = load_dataset(
+                ds_type,
+                name=config_dataset.name,
+                data_files=config_dataset.path,
+                streaming=False,
+                split=None,
+                storage_options=storage_options,
+                trust_remote_code=config_dataset.trust_remote_code,
+            )
+    elif config_dataset.path.startswith("https://"):
+        ds_type = get_ds_type(config_dataset)
+        ds = load_dataset(
+            ds_type,
+            name=config_dataset.name,
+            data_files=config_dataset.path,
+            streaming=False,
+            split=None,
+            storage_options=storage_options,
+            trust_remote_code=config_dataset.trust_remote_code,
+        )
+    else:
+        if isinstance(config_dataset.data_files, str):
+            fp = hf_hub_download(
+                repo_id=config_dataset.path,
+                repo_type="dataset",
+                filename=config_dataset.data_files,
+                revision=config_dataset.revision,
+            )
+        elif isinstance(config_dataset.data_files, list):
+            fp = []
+            for file in config_dataset.data_files:
+                fp.append(
+                    hf_hub_download(
+                        repo_id=config_dataset.path,
+                        repo_type="dataset",
+                        filename=file,
+                        revision=config_dataset.revision,
+                    )
+                )
+        else:
+            raise ValueError("data_files must be either a string or list of strings")
+        ds = load_dataset(
+            "json",
+            name=config_dataset.name,
+            data_files=fp,
+            streaming=False,
+            split=None,
+        )
+    if not ds:
+        raise ValueError("unhandled dataset load")
+
+    return ds
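
One detail of `load_dataset_w_config` worth calling out: the probe at the top opens the dataset with `streaming=True`, which per the inline comment is just a basic check that the path is a loadable HF dataset before any non-streaming load is attempted. A sketch of that idea in isolation; `looks_like_hub_dataset` is a hypothetical helper name, not part of this PR:

```python
from datasets import load_dataset
from huggingface_hub.errors import HFValidationError


def looks_like_hub_dataset(path: str) -> bool:
    """Cheap probe mirroring load_dataset_w_config: a streaming open
    validates the repo without materializing the data files."""
    try:
        load_dataset(path, streaming=True)
        return True
    except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
        return False
```
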