Skip to content

Commit

Permalink
Merge pull request #1101 from activeloopai/small-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
thisiseshan authored Aug 9, 2021
2 parents 78a2f3d + 894b544 commit c88c774
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 45 deletions.
22 changes: 13 additions & 9 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def like(
def ingest(
src: str,
dest: str,
dest_creds: dict,
compression: str,
images_compression: str,
dest_creds: dict = None,
overwrite: bool = False,
**dataset_kwargs,
) -> Dataset:
Expand Down Expand Up @@ -314,8 +314,8 @@ def ingest(
- an s3 path of the form s3://bucketname/path/to/dataset. Credentials are required in either the environment or passed to the creds argument.
- a local file system path of the form ./path/to/dataset or ~/path/to/dataset or path/to/dataset.
- a memory path of the form mem://path/to/dataset which doesn't save the dataset but keeps it in memory instead. Should be used only for testing as it does not persist.
images_compression (str): For image classification datasets, this compression will be used for the `images` tensor.
dest_creds (dict): A dictionary containing credentials used to access the destination path of the dataset.
compression (str): Compression type of dataset.
overwrite (bool): WARNING: If set to True this overwrites the dataset if it already exists. This can NOT be undone! Defaults to False.
**dataset_kwargs: Any arguments passed here will be forwarded to the dataset creator function.
Expand All @@ -340,7 +340,7 @@ def ingest(

# TODO: auto detect compression
unstructured.structure(
ds, image_tensor_args={"sample_compression": compression} # type: ignore
ds, image_tensor_args={"sample_compression": images_compression} # type: ignore
)

return ds # type: ignore
Expand All @@ -351,8 +351,9 @@ def ingest_kaggle(
tag: str,
src: str,
dest: str,
dest_creds: dict,
compression: str,
images_compression: str,
dest_creds: dict = None,
kaggle_credentials: dict = None,
overwrite: bool = False,
**dataset_kwargs,
) -> Dataset:
Expand All @@ -369,8 +370,9 @@ def ingest_kaggle(
- an s3 path of the form s3://bucketname/path/to/dataset. Credentials are required in either the environment or passed to the creds argument.
- a local file system path of the form ./path/to/dataset or ~/path/to/dataset or path/to/dataset.
- a memory path of the form mem://path/to/dataset which doesn't save the dataset but keeps it in memory instead. Should be used only for testing as it does not persist.
images_compression (str): For image classification datasets, this compression will be used for the `images` tensor.
dest_creds (dict): A dictionary containing credentials used to access the destination path of the dataset.
compression (str): Compression type of dataset.
kaggle_credentials (dict): A dictionary containing kaggle credentials {"username":"YOUR_USERNAME", "key": "YOUR_KEY"}. If None, environment variables/the kaggle.json file will be used if available.
overwrite (bool): WARNING: If set to True this overwrites the dataset if it already exists. This can NOT be undone! Defaults to False.
**dataset_kwargs: Any arguments passed here will be forwarded to the dataset creator function.
Expand All @@ -384,13 +386,15 @@ def ingest_kaggle(
if os.path.samefile(src, dest):
raise SamePathException(src)

download_kaggle_dataset(tag, local_path=src)
download_kaggle_dataset(
tag, local_path=src, kaggle_credentials=kaggle_credentials
)

ds = hub.ingest(
src=src,
dest=dest,
dest_creds=dest_creds,
compression=compression,
images_compression=images_compression,
overwrite=overwrite,
**dataset_kwargs,
)
Expand Down
57 changes: 44 additions & 13 deletions hub/auto/tests/_test_kaggle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from hub.auto.unstructured import kaggle
from hub.api.dataset import Dataset
from hub.util.exceptions import KaggleDatasetAlreadyDownloadedError, SamePathException
from hub.util.exceptions import (
KaggleDatasetAlreadyDownloadedError,
SamePathException,
KaggleMissingCredentialsError,
ExternalCommandError,
)
from hub.tests.common import get_dummy_data_path
import pytest
import os
Expand All @@ -12,13 +18,12 @@ def test_ingestion_simple(local_ds: Dataset):
tag="andradaolteanu/birdcall-recognition-data",
src=kaggle_path,
dest=local_ds.path,
dest_creds={},
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

assert list(ds.tensors.keys()) == ["images", "labels"]
assert ds["labels"].numpy().shape == (10,)
assert ds["labels"].numpy().shape == (10, 1)


def test_ingestion_sets(local_ds: Dataset):
Expand All @@ -28,8 +33,7 @@ def test_ingestion_sets(local_ds: Dataset):
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
dest_creds={},
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

Expand Down Expand Up @@ -57,25 +61,52 @@ def test_kaggle_exception(local_ds: Dataset):
tag="thisiseshan/bird-classes",
src=dummy_path,
dest=dummy_path,
dest_creds=None,
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

with pytest.raises(KaggleDatasetAlreadyDownloadedError):
with pytest.raises(KaggleMissingCredentialsError):
hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
dest_creds={},
compression="jpeg",
images_compression="jpeg",
kaggle_credentials={"not_username": "not_username"},
overwrite=False,
)

with pytest.raises(KaggleMissingCredentialsError):
hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
kaggle_credentials={"username": "thisiseshan", "not_key": "not_key"},
overwrite=False,
)

with pytest.raises(ExternalCommandError):
hub.ingest_kaggle(
tag="thisiseshan/invalid-dataset",
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
overwrite=False,
)

hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
overwrite=False,
)

with pytest.raises(KaggleDatasetAlreadyDownloadedError):
hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
dest_creds={},
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)
21 changes: 9 additions & 12 deletions hub/auto/tests/test_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,17 @@ def test_ingestion_simple(memory_ds: Dataset):
hub.ingest(
src="tests_auto/invalid_path",
dest=memory_ds.path,
dest_creds=None,
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

with pytest.raises(SamePathException):
hub.ingest(
src=path, dest=path, dest_creds=None, compression="jpeg", overwrite=False
)
hub.ingest(src=path, dest=path, images_compression="jpeg", overwrite=False)

ds = hub.ingest(
src=path,
dest=memory_ds.path,
dest_creds=None,
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

Expand All @@ -41,8 +37,7 @@ def test_image_classification_sets(memory_ds: Dataset):
ds = hub.ingest(
src=path,
dest=memory_ds.path,
dest_creds=None,
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

Expand All @@ -67,12 +62,14 @@ def test_ingestion_exception(memory_ds: Dataset):
hub.ingest(
src="tests_auto/invalid_path",
dest=memory_ds.path,
dest_creds=None,
compression="jpeg",
images_compression="jpeg",
overwrite=False,
)

with pytest.raises(SamePathException):
hub.ingest(
src=path, dest=path, dest_creds=None, compression="jpeg", overwrite=False
src=path,
dest=path,
images_compression="jpeg",
overwrite=False,
)
24 changes: 13 additions & 11 deletions hub/auto/unstructured/kaggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ def _exec_command(command):


def _set_environment_credentials_if_none(kaggle_credentials: dict = None):
if kaggle_credentials is None:
kaggle_credentials = {}
if kaggle_credentials is not None:
username = kaggle_credentials.get("username", None)
if not username:
raise KaggleMissingCredentialsError(ENV_KAGGLE_USERNAME)
os.environ[ENV_KAGGLE_USERNAME] = username
key = kaggle_credentials.get("key", None)
if not key:
raise KaggleMissingCredentialsError(ENV_KAGGLE_KEY)
os.environ[ENV_KAGGLE_KEY] = key
else:
if ENV_KAGGLE_USERNAME not in os.environ:
username = kaggle_credentials.get("username", None)
if not username:
raise KaggleMissingCredentialsError(ENV_KAGGLE_USERNAME)
os.environ[ENV_KAGGLE_USERNAME] = username
raise KaggleMissingCredentialsError(ENV_KAGGLE_USERNAME)
if ENV_KAGGLE_KEY not in os.environ:
key = kaggle_credentials.get("key", None)
if not key:
raise KaggleMissingCredentialsError(ENV_KAGGLE_KEY)
os.environ[ENV_KAGGLE_KEY] = key
raise KaggleMissingCredentialsError(ENV_KAGGLE_KEY)


def download_kaggle_dataset(tag: str, local_path: str, kaggle_credentials: dict = {}):
def download_kaggle_dataset(tag: str, local_path: str, kaggle_credentials: dict = None):
"""Calls the kaggle API (https://www.kaggle.com/docs/api) to download a kaggle dataset and unzip it's contents.
Args:
Expand Down

0 comments on commit c88c774

Please sign in to comment.