Compare commits
10 Commits
djsaunde-p
...
optimizer-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7b095d77f | ||
|
|
7a38dbe674 | ||
|
|
e0a2eb2ebd | ||
|
|
d852d7af7a | ||
|
|
3742deb1de | ||
|
|
2312caaa98 | ||
|
|
307cf7c685 | ||
|
|
70541145f1 | ||
|
|
42bd32a233 | ||
|
|
5b8fb5e939 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
**/axolotl.egg-info
|
||||
configs
|
||||
last_run_prepared/
|
||||
outputs
|
||||
.vscode
|
||||
_site/
|
||||
|
||||
|
||||
@@ -5,6 +5,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
27
deepspeed_configs/zero1_torch_compile.json
Normal file
27
deepspeed_configs/zero1_torch_compile.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 1,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"compile": {
|
||||
"disable": false,
|
||||
"backend": "inductor"
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.1b2
|
||||
axolotl-contribs-lgpl==0.0.2
|
||||
|
||||
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
)
|
||||
@click.option(
|
||||
@@ -124,7 +124,7 @@ def inference(
|
||||
if lora_model_dir:
|
||||
kwargs["lora_model_dir"] = lora_model_dir
|
||||
if base_model:
|
||||
kwargs["output_dir"] = base_model
|
||||
kwargs["base_model"] = base_model
|
||||
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
|
||||
@@ -56,6 +56,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GCCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
@@ -67,7 +68,7 @@ from axolotl.utils.callbacks import (
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
@@ -423,6 +424,11 @@ class SchedulerMixin(Trainer):
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||
if not checkpoint and self.args.optimizer_checkpoint is not None:
|
||||
checkpoint = self.args.optimizer_checkpoint
|
||||
return super()._load_optimizer_and_scheduler(checkpoint)
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
"""
|
||||
@@ -1452,6 +1458,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
@@ -1761,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
if self.cfg.optimizer_checkpoint:
|
||||
training_arguments_kwargs[
|
||||
"optimizer_checkpoint"
|
||||
] = self.cfg.optimizer_checkpoint
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
@@ -1831,8 +1843,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||
self.cfg.chat_template,
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -126,7 +127,20 @@ def train(
|
||||
)
|
||||
|
||||
if cfg.fix_untrained_tokens:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -842,3 +843,17 @@ class SaveModelCallback(TrainerCallback):
|
||||
):
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
|
||||
class GCCallback(TrainerCallback):
|
||||
"""Callback to garbage collect torch cache"""
|
||||
|
||||
def __init__(self, gc_steps=None):
|
||||
self.gc_steps = gc_steps
|
||||
|
||||
def on_step_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
if state.global_step % self.gc_steps == 0:
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -603,6 +603,8 @@ class AxolotlInputConfig(
|
||||
strict: Optional[bool] = Field(default=False)
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
optimizer_checkpoint: Optional[str] = None
|
||||
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
|
||||
@@ -666,6 +668,8 @@ class AxolotlInputConfig(
|
||||
loss_watchdog_threshold: Optional[float] = None
|
||||
loss_watchdog_patience: Optional[int] = None
|
||||
|
||||
gc_steps: Optional[int] = None
|
||||
|
||||
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||
fp16: Optional[bool] = None
|
||||
bfloat16: Optional[bool] = None # for non-AMP cases
|
||||
@@ -792,7 +796,7 @@ class AxolotlInputConfig(
|
||||
chat_template_jinja: Optional[str] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
fix_untrained_tokens: Optional[bool] = None
|
||||
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
|
||||
|
||||
# INTERNALS - document for now, generally not set externally
|
||||
is_preprocess: Optional[bool] = None
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
@@ -12,8 +12,6 @@ from datasets import (
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HFValidationError
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
@@ -42,6 +40,7 @@ from axolotl.prompters import (
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.shared import load_dataset_w_config
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
md5,
|
||||
@@ -85,6 +84,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
processor=processor,
|
||||
)
|
||||
else:
|
||||
# Load streaming dataset if pretraining_dataset is given
|
||||
path = cfg.pretraining_dataset
|
||||
split = "train"
|
||||
name = None
|
||||
@@ -116,7 +116,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
|
||||
# Load eval dataset (non-streaming) if specified
|
||||
eval_dataset = None
|
||||
if cfg.test_datasets:
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
LOG.info("Deduplication not available for pretrained datasets")
|
||||
|
||||
@@ -243,195 +254,9 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
pass
|
||||
|
||||
ds_from_cloud = False
|
||||
storage_options = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
try:
|
||||
import aiobotocore.session # type: ignore
|
||||
import s3fs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||
) from exc
|
||||
|
||||
# Takes credentials from ~/.aws/credentials for default profile
|
||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||
storage_options = {"session": s3_session}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith(
|
||||
"gs://"
|
||||
) or config_dataset.path.startswith("gcs://"):
|
||||
try:
|
||||
import gcsfs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
# TODO: Figure out how to get auth creds passed
|
||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||
# try:
|
||||
# import adlfs
|
||||
# except ImportError as exc:
|
||||
# raise ImportError(
|
||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||
# ) from exc
|
||||
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": TENANT_ID,
|
||||
# "client_id": CLIENT_ID,
|
||||
# "client_secret": CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": ACCOUNT_NAME,
|
||||
# "account_key": ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
try:
|
||||
if remote_file_system and remote_file_system.exists(
|
||||
config_dataset.path
|
||||
):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
if config_dataset.data_files:
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"data_files must be either a string or list of strings"
|
||||
)
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token
|
||||
)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
d_type = config_dataset.type
|
||||
@@ -501,24 +326,6 @@ def load_tokenized_prepared_datasets(
|
||||
return dataset, prompters
|
||||
|
||||
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
|
||||
222
src/axolotl/utils/data/shared.py
Normal file
222
src/axolotl/utils/data/shared.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
dataset loading shared utils
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HFValidationError
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
|
||||
|
||||
def load_dataset_w_config(config_dataset, auth_token):
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
token=auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
pass
|
||||
|
||||
ds_from_cloud = False
|
||||
storage_options = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
try:
|
||||
import aiobotocore.session # type: ignore
|
||||
import s3fs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||
) from exc
|
||||
|
||||
# Takes credentials from ~/.aws/credentials for default profile
|
||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||
storage_options = {"session": s3_session}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
|
||||
"gcs://"
|
||||
):
|
||||
try:
|
||||
import gcsfs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
# TODO: Figure out how to get auth creds passed
|
||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||
# try:
|
||||
# import adlfs
|
||||
# except ImportError as exc:
|
||||
# raise ImportError(
|
||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||
# ) from exc
|
||||
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": TENANT_ID,
|
||||
# "client_id": CLIENT_ID,
|
||||
# "client_secret": CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": ACCOUNT_NAME,
|
||||
# "account_key": ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
try:
|
||||
if remote_file_system and remote_file_system.exists(config_dataset.path):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
if config_dataset.data_files:
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ds = load_from_disk(
|
||||
config_dataset.path
|
||||
) # pylint: disable=invalid-name
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
token=auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("data_files must be either a string or list of strings")
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
|
||||
return ds
|
||||
Reference in New Issue
Block a user