Skip to content

Commit

Permalink
Merge pull request #481 from activeloopai/fixes/to_tensorflow
Browse files Browse the repository at this point in the history
Adds chunk optimization to to_tensorflow
  • Loading branch information
AbhinavTuli authored Jan 23, 2021
2 parents c4ba207 + f3a24c5 commit 095b695
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 113 deletions.
28 changes: 24 additions & 4 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,15 +568,35 @@ def to_tensorflow(self, indexes=None):
num_samples: int, optional
The number of samples required of the dataset that needs to be converted
"""
if "tensorflow" not in sys.modules:
raise ModuleNotInstalledException("tensorflow")
else:
try:
import tensorflow as tf

global tf
except ModuleNotFoundError:
raise ModuleNotInstalledException("tensorflow")

indexes = indexes or self.indexes
indexes = [indexes] if isinstance(indexes, int) else indexes
# Number of samples stored per chunk for each tensor key.
# Dynamically-shaped tensors (a None in the shape) are read one sample at
# a time; fixed-shape tensors can be read a whole chunk at a time.
# Uses an explicit conditional expression instead of the fragile
# `cond and a or b` idiom (which breaks whenever `a` is falsy).
_samples_in_chunks = {
    key: 1 if None in value.shape else value.chunks[0]
    for key, value in self._tensors.items()
}
_active_chunks = {}
_active_chunks_range = {}

def _get_active_item(key, index):
    # Return sample `index` of tensor `key`, serving it from a cached
    # chunk; the cache is refreshed whenever the requested index falls
    # outside the currently cached range, so sequential reads hit
    # storage once per chunk instead of once per sample.
    chunk_size = _samples_in_chunks[key]
    offset = index % chunk_size
    cached_range = _active_chunks_range.get(key)
    if cached_range is None or index not in cached_range:
        start = index - offset
        cached_range = range(start, start + chunk_size)
        _active_chunks_range[key] = cached_range
        _active_chunks[key] = self._tensors[key][start : start + chunk_size]
    return _active_chunks[key][offset]

def tf_gen():
for index in indexes:
Expand All @@ -590,7 +610,7 @@ def tf_gen():
else:
cur[split_key[i]] = {}
cur = cur[split_key[i]]
cur[split_key[-1]] = self._tensors[key][index]
cur[split_key[-1]] = _get_active_item(key, index)
yield (d)

def dict_to_tf(my_dtype):
Expand Down
7 changes: 4 additions & 3 deletions hub/api/sharded_datasetview.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def shape(self):
def __len__(self):
return self.num_samples

def __str__(self):
    # Concise human-readable summary; the shape tuple renders via its
    # default str() form, identical to the original f-string output.
    return "ShardedDatasetView(shape={})".format(self.shape)

def __repr__(self):
    # Interactive display mirrors print output by delegating to __str__.
    return str(self)

Expand Down Expand Up @@ -53,9 +56,7 @@ def slicing(self, slice_):
slice_ = list(slice_)
if not isinstance(slice_[0], int):
# TODO add advanced slicing options
raise AdvancedSlicingNotSupported(
"No slicing since there is no currently cross sharded dataset support"
)
raise AdvancedSlicingNotSupported()

shard_id, offset = self.identify_shard(slice_[0])
slice_[0] = slice_[0] - offset
Expand Down
13 changes: 13 additions & 0 deletions hub/api/tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ def test_from_tensorflow():
assert res_ds["b"].numpy().tolist() == [5, 6]


@pytest.mark.skipif(not tensorflow_loaded(), reason="requires tensorflow to be loaded")
def test_to_tensorflow():
    """Round-trip check: samples written to a Dataset come back unchanged
    when iterated through to_tensorflow()."""
    sample_count = 10
    ds = hub.Dataset(
        "./data/test_to_tf",
        shape=(sample_count,),
        schema={"abc": Tensor((100, 100, 3)), "int": "uint32"},
    )
    for idx in range(sample_count):
        ds["abc", idx] = idx * np.ones((100, 100, 3))
        ds["int", idx] = idx
    for idx, sample in enumerate(ds.to_tensorflow()):
        expected = idx * np.ones((100, 100, 3))
        assert (sample["abc"].numpy() == expected).all()
        assert sample["int"] == idx


@pytest.mark.skipif(not tensorflow_loaded(), reason="requires tensorflow to be loaded")
def test_to_from_tensorflow():
my_schema = {
Expand Down
15 changes: 14 additions & 1 deletion hub/api/tests/test_sharded_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from hub.schema.features import SchemaDict
from hub.exceptions import AdvancedSlicingNotSupported
from hub.api.sharded_datasetview import ShardedDatasetView
from hub import Dataset
import pytest


def test_sharded_dataset():
Expand All @@ -13,9 +16,19 @@ def test_sharded_dataset():

ds[0]["first"] = 2.3
assert ds[0]["second"].numpy() != 2.3
print(ds[30]["first"].numpy())
assert ds[30]["first"].numpy() == 0
assert len(ds) == 40
assert ds.shape == (40,)
assert type(ds.schema) == SchemaDict
assert ds.__repr__() == "ShardedDatasetView(shape=(40,))"
with pytest.raises(AdvancedSlicingNotSupported):
ds[5:8]
ds[4, "first"] = 3
for _ in ds:
pass

ds2 = ShardedDatasetView([])
assert ds2.identify_shard(5) == (0, 0)


if __name__ == "__main__":
Expand Down
105 changes: 0 additions & 105 deletions hub/api/tests/tfds_meta/mnist/dataset_info.json

This file was deleted.

0 comments on commit 095b695

Please sign in to comment.