Compare commits

...

10 Commits

Author SHA1 Message Date
Wing Lian
c7b095d77f resume from optimizer checkpoint only 2024-12-30 07:51:44 -05:00
NJordan72
7a38dbe674 fix: allow trainer builder to use custom jinja chat template (#2219)
* fix: allow trainer builder to use custom jinja chat template

* chore: use get_chat_template_from_config

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>

* fix: swap imports

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-12-24 16:18:50 -05:00
Wing Lian
e0a2eb2ebd fix untrained tokens if specified explicitly from a list (#2210) 2024-12-23 09:08:28 -05:00
Wing Lian
d852d7af7a inference - don't default w accelerate, fix base model (#2216) [skip ci] 2024-12-23 07:48:41 -05:00
Wing Lian
3742deb1de add deepspeed example with torch compile enabled (#2212) [skip ci] 2024-12-22 12:11:39 -05:00
Wing Lian
2312caaa98 GC every n steps (#2209) 2024-12-21 17:38:33 -05:00
Wing Lian
307cf7c685 move the dataset loading from remote/disk to a shared function so we can re-use for RL (#2204) 2024-12-20 21:43:52 -05:00
Dan Saunders
70541145f1 adding test_datasets compat with pretraining_dataset (streaming) (#2206) [skip ci] 2024-12-20 21:43:33 -05:00
Wing Lian
42bd32a233 add outputs (symlink) to gitignore [skip ci] (#2205) 2024-12-19 20:14:43 -05:00
Dan Saunders
5b8fb5e939 remove cicd pytest xdist args (#2201)
* remove cicd pytest xdist args

* Delete outputs
2024-12-19 11:44:53 -05:00
11 changed files with 322 additions and 220 deletions

1
.gitignore vendored
View File

@@ -1,6 +1,7 @@
**/axolotl.egg-info **/axolotl.egg-info
configs configs
last_run_prepared/ last_run_prepared/
outputs
.vscode .vscode
_site/ _site/

View File

@@ -5,6 +5,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -0,0 +1,27 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"compile": {
"disable": false,
"backend": "inductor"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.1b2 axolotl-contribs-lgpl==0.0.2

View File

@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option( @click.option(
"--accelerate/--no-accelerate", "--accelerate/--no-accelerate",
default=True, default=False,
help="Use accelerate launch for multi-GPU inference", help="Use accelerate launch for multi-GPU inference",
) )
@click.option( @click.option(
@@ -124,7 +124,7 @@ def inference(
if lora_model_dir: if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir kwargs["lora_model_dir"] = lora_model_dir
if base_model: if base_model:
kwargs["output_dir"] = base_model kwargs["base_model"] = base_model
if accelerate: if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]

View File

@@ -56,6 +56,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GCCallback,
GPUStatsCallback, GPUStatsCallback,
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
@@ -67,7 +68,7 @@ from axolotl.utils.callbacks import (
) )
from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import ( from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
@@ -423,6 +424,11 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler return self.lr_scheduler
def _load_optimizer_and_scheduler(self, checkpoint):
if not checkpoint and self.args.optimizer_checkpoint is not None:
checkpoint = self.args.optimizer_checkpoint
return super()._load_optimizer_and_scheduler(checkpoint)
class AxolotlTrainer(SchedulerMixin, Trainer): class AxolotlTrainer(SchedulerMixin, Trainer):
""" """
@@ -1452,6 +1458,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None: if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg)) callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback()) callbacks.append(SaveModelCallback())
return callbacks return callbacks
@@ -1761,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding ] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
if self.cfg.optimizer_checkpoint:
training_arguments_kwargs[
"optimizer_checkpoint"
] = self.cfg.optimizer_checkpoint
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1831,8 +1843,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template: if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template( training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
self.cfg.chat_template, cfg=self.cfg,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
) )

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os import os
import signal import signal
import sys import sys
@@ -126,7 +127,20 @@ def train(
) )
if cfg.fix_untrained_tokens: if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset) # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization str(Path(cfg.output_dir)), safe_serialization=safe_serialization

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import gc
import logging import logging
import math import math
import os import os
@@ -842,3 +843,17 @@ class SaveModelCallback(TrainerCallback):
): ):
control.should_save = True control.should_save = True
return control return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""
def __init__(self, gc_steps=None):
self.gc_steps = gc_steps
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()

View File

@@ -603,6 +603,8 @@ class AxolotlInputConfig(
strict: Optional[bool] = Field(default=False) strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None auto_resume_from_checkpoints: Optional[bool] = None
optimizer_checkpoint: Optional[str] = None
resize_token_embeddings_to_32x: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False mean_resizing_embeddings: Optional[bool] = False
@@ -666,6 +668,8 @@ class AxolotlInputConfig(
loss_watchdog_threshold: Optional[float] = None loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None loss_watchdog_patience: Optional[int] = None
gc_steps: Optional[int] = None
bf16: Optional[Union[Literal["auto"], bool]] = "auto" bf16: Optional[Union[Literal["auto"], bool]] = "auto"
fp16: Optional[bool] = None fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases bfloat16: Optional[bool] = None # for non-AMP cases
@@ -792,7 +796,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None fix_untrained_tokens: Optional[Union[int, List[int]]] = None
# INTERNALS - document for now, generally not set externally # INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None is_preprocess: Optional[bool] = None

View File

@@ -3,7 +3,7 @@
import functools import functools
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Tuple, Union
from datasets import ( from datasets import (
Dataset, Dataset,
@@ -12,8 +12,6 @@ from datasets import (
load_dataset, load_dataset,
load_from_disk, load_from_disk,
) )
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@ from axolotl.prompters import (
UnsupportedPrompter, UnsupportedPrompter,
) )
from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import ( from axolotl.utils.data.utils import (
deduplicate_and_log_datasets, deduplicate_and_log_datasets,
md5, md5,
@@ -85,6 +84,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
processor=processor, processor=processor,
) )
else: else:
# Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
split = "train" split = "train"
name = None name = None
@@ -116,7 +116,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
) )
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
# Load eval dataset (non-streaming) if specified
eval_dataset = None eval_dataset = None
if cfg.test_datasets:
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
)
if cfg.dataset_exact_deduplication: if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets") LOG.info("Deduplication not available for pretrained datasets")
@@ -243,195 +254,9 @@ def load_tokenized_prepared_datasets(
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets): for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
ds_from_hub = False config_dataset, use_auth_token
ds_trust_remote_code = config_dataset.trust_remote_code )
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(config_dataset.path)
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = config_dataset.type d_type = config_dataset.type
@@ -501,24 +326,6 @@ def load_tokenized_prepared_datasets(
return dataset, prompters return dataset, prompters
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,

View File

@@ -0,0 +1,222 @@
"""
dataset loading shared utils
"""
from pathlib import Path
from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HFValidationError
from axolotl.utils.dict import DictDefault
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_dataset_w_config(config_dataset, auth_token):
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
return ds