diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 1bb83efd5..e657262b9 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -96,20 +96,17 @@ def load_dataset_w_config( pass ds_from_cloud = False - storage_options = {} + storage_options: dict = {} remote_file_system = None if config_dataset.path.startswith("s3://"): try: - import aiobotocore.session # type: ignore import s3fs # type: ignore except ImportError as exc: - raise ImportError( - "s3:// paths require aiobotocore and s3fs to be installed" - ) from exc + raise ImportError("s3:// paths require s3fs to be installed") from exc - # Takes credentials from ~/.aws/credentials for default profile - s3_session = aiobotocore.session.AioSession(profile="default") - storage_options = {"session": s3_session} + # Reads env, credentials from ~/.aws/credentials, or IAM metadata provider + # https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials + storage_options = {"anon": False} remote_file_system = s3fs.S3FileSystem(**storage_options) elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith( "gcs://" @@ -125,28 +122,44 @@ def load_dataset_w_config( # https://gcsfs.readthedocs.io/en/latest/#credentials storage_options = {"token": None} remote_file_system = gcsfs.GCSFileSystem(**storage_options) - # TODO: Figure out how to get auth creds passed - # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): - # try: - # import adlfs - # except ImportError as exc: - # raise ImportError( - # "adl:// or abfs:// paths require adlfs to be installed" - # ) from exc + elif ( + config_dataset.path.startswith("adl://") + or config_dataset.path.startswith("abfs://") + or config_dataset.path.startswith("az://") + ): + try: + import adlfs + except ImportError as exc: + raise ImportError( + "adl:// or abfs:// paths require adlfs to be installed" + ) from exc - # # Gen 1 - # storage_options = { - # "tenant_id": TENANT_ID, - # "client_id": CLIENT_ID, - # "client_secret": CLIENT_SECRET, - # } - # # Gen 2 - # storage_options = { - # "account_name": ACCOUNT_NAME, - # "account_key": ACCOUNT_KEY, - # } + # # Ensure you have the following environment variables set: + # # Gen 1 + # storage_options = { + # "tenant_id": AZURE_STORAGE_TENANT_ID, + # "client_id": AZURE_STORAGE_CLIENT_ID, + # "client_secret": AZURE_STORAGE_CLIENT_SECRET, + # } + # # Gen 2 + # storage_options = { + # "account_name": AZURE_STORAGE_ACCOUNT_NAME, + # "account_key": AZURE_STORAGE_ACCOUNT_KEY, + # } + + # Reads env + # https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials + storage_options = {"anon": False} + remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) + elif config_dataset.path.startswith("oci://"): + try: + import ocifs + except ImportError as exc: + raise ImportError("oci:// paths require ocifs to be installed") from exc + + # https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables + remote_file_system = ocifs.OCIFileSystem(**storage_options) - # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) try: if remote_file_system and remote_file_system.exists(config_dataset.path): ds_from_cloud = True