Skip to content

Commit

Permalink
Merge pull request #1112 from activeloopai/fix/num_workers_0_bug
Browse files Browse the repository at this point in the history
Fix/num workers 0 bug
  • Loading branch information
davidbuniat authored Aug 11, 2021
2 parents e34d67c + 6a8d1eb commit 4e84e1c
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"ingest_kaggle",
]

__version__ = "2.0.5"
__version__ = "2.0.6"
__encoded_version__ = np.array(__version__)

hub_reporter.tags.append(f"version:{__version__}")
Expand Down
8 changes: 5 additions & 3 deletions hub/core/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from click.testing import CliRunner
from hub.core.storage.memory import MemoryProvider
from hub.util.remove_cache import remove_memory_cache
from hub.tests.common import parametrize_num_workers
from hub.tests.dataset_fixtures import enabled_datasets
from hub.util.exceptions import InvalidOutputDatasetError

Expand Down Expand Up @@ -71,7 +72,8 @@ def test_single_transform_hub_dataset(ds):


@enabled_datasets
def test_single_transform_hub_dataset_htypes(ds):
@parametrize_num_workers
def test_single_transform_hub_dataset_htypes(ds, num_workers):
with CliRunner().isolated_filesystem():
with hub.dataset("./test/transform_hub_in_htypes") as data_in:
data_in.create_tensor("image", htype="image", sample_compression="png")
Expand All @@ -83,7 +85,7 @@ def test_single_transform_hub_dataset_htypes(ds):
ds_out = ds
ds_out.create_tensor("image")
ds_out.create_tensor("label")
fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=5)
fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=num_workers)
assert len(ds_out) == 99
for index in range(1, 100):
np.testing.assert_array_equal(
Expand All @@ -104,7 +106,7 @@ def test_chain_transform_list_small(ds):
ds_out.create_tensor("image")
ds_out.create_tensor("label")
pipeline = hub.compose([fn1(mul=5, copy=2), fn2(mul=3, copy=3)])
pipeline.eval(ls, ds_out, num_workers=5)
pipeline.eval(ls, ds_out, num_workers=3)
assert len(ds_out) == 600
for i in range(100):
for index in range(6 * i, 6 * i + 6):
Expand Down
1 change: 1 addition & 0 deletions hub/core/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run(
"""Runs the pipeline on the input data to produce output samples and stores in the dataset.
This receives arguments processed and sanitized by the Pipeline.eval method.
"""
num_workers = max(num_workers, 1)
size = math.ceil(len(data_in) / num_workers)
slices = [data_in[i * size : (i + 1) * size] for i in range(num_workers)]

Expand Down
2 changes: 1 addition & 1 deletion hub/requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pathos
humbug>=0.2.6
types-requests
types-click
tqdm
tqdm
2 changes: 1 addition & 1 deletion hub/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ numcodecs~=0.7.3
Pillow~=8.2.0
lz4~=3.1.3
zstd~=1.4.5
requests~=2.25.1
requests~=2.25.1
3 changes: 3 additions & 0 deletions hub/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
NUM_BATCHES_PARAM = "num_batches"
DTYPE_PARAM = "dtype"
CHUNK_SIZE_PARAM = "chunk_size"
NUM_WORKERS_PARAM = "num_workers"

NUM_BATCHES = (1, 5)
NUM_WORKERS = (0, 1, 2, 4)

CHUNK_SIZES = (
1 * KB,
Expand All @@ -39,6 +41,7 @@
parametrize_chunk_sizes = pytest.mark.parametrize(CHUNK_SIZE_PARAM, CHUNK_SIZES)
parametrize_dtypes = pytest.mark.parametrize(DTYPE_PARAM, DTYPES)
parametrize_num_batches = pytest.mark.parametrize(NUM_BATCHES_PARAM, NUM_BATCHES)
parametrize_num_workers = pytest.mark.parametrize(NUM_WORKERS_PARAM, NUM_WORKERS)


def current_test_name() -> str:
Expand Down
12 changes: 9 additions & 3 deletions hub/util/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,11 @@ def check_transform_data_in(data_in, scheduler: str) -> None:
f"The data_in to transform is invalid. It should support __len__ operation."
)
if isinstance(data_in, hub.core.dataset.Dataset):
base_storage = get_base_storage(data_in.storage)
if isinstance(base_storage, MemoryProvider) and scheduler != "threaded":
input_base_storage = get_base_storage(data_in.storage)
if isinstance(input_base_storage, MemoryProvider) and scheduler not in [
"serial",
"threaded",
]:
raise InvalidOutputDatasetError(
f"Transforms with data_in as a Dataset having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}."
)
Expand All @@ -191,7 +194,10 @@ def check_transform_ds_out(ds_out: hub.core.dataset.Dataset, scheduler: str) ->
)

output_base_storage = get_base_storage(ds_out.storage)
if isinstance(output_base_storage, MemoryProvider) and scheduler != "threaded":
if isinstance(output_base_storage, MemoryProvider) and scheduler not in [
"serial",
"threaded",
]:
raise InvalidOutputDatasetError(
f"Transforms with ds_out having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}."
)

0 comments on commit 4e84e1c

Please sign in to comment.