diff --git a/benchmarks/bench_png_compression.py b/benchmarks/bench_png_compression.py index ae87e44e42..3980b3f497 100644 --- a/benchmarks/bench_png_compression.py +++ b/benchmarks/bench_png_compression.py @@ -8,7 +8,7 @@ img_path = "./benchmarks/sample.png" -count = 12 +count = 20 def bench_pil_compression(img_path=img_path, count=count): diff --git a/hub/defaults.py b/hub/defaults.py index eac27f805c..0f86ccc264 100644 --- a/hub/defaults.py +++ b/hub/defaults.py @@ -1,2 +1,3 @@ CHUNK_DEFAULT_SIZE = 2 ** 24 OBJECT_CHUNK = 128 +DEFAULT_COMPRESSOR = "default" \ No newline at end of file diff --git a/hub/store/dynamic_tensor.py b/hub/store/dynamic_tensor.py index 9834e86472..731fc6661c 100644 --- a/hub/store/dynamic_tensor.py +++ b/hub/store/dynamic_tensor.py @@ -10,6 +10,7 @@ from hub.store.nested_store import NestedStore from hub.store.shape_detector import ShapeDetector +from hub.defaults import DEFAULT_COMPRESSOR from hub.exceptions import ( DynamicTensorNotFoundException, @@ -42,7 +43,7 @@ def __init__( max_shape=None, dtype="float64", chunks=None, - compressor="default", + compressor=DEFAULT_COMPRESSOR, ): """Constructor Parameters @@ -64,7 +65,9 @@ def __init__( """ if not (shape is None): # otherwise shape detector fails - shapeDt = ShapeDetector(shape, max_shape, chunks, dtype) + shapeDt = ShapeDetector( + shape, max_shape, chunks, dtype, compressor=compressor + ) shape = shapeDt.shape max_shape = shapeDt.max_shape chunks = shapeDt.chunks diff --git a/hub/store/shape_detector.py b/hub/store/shape_detector.py index ec180c3b17..e0de291e6d 100644 --- a/hub/store/shape_detector.py +++ b/hub/store/shape_detector.py @@ -1,8 +1,9 @@ +from hub.numcodecs import PngCodec import math import numpy as np -from hub.defaults import CHUNK_DEFAULT_SIZE, OBJECT_CHUNK +from hub.defaults import CHUNK_DEFAULT_SIZE, OBJECT_CHUNK, DEFAULT_COMPRESSOR from hub.exceptions import HubException @@ -17,19 +18,27 @@ def __init__( dtype="float64", chunksize=CHUNK_DEFAULT_SIZE, object_chunking=OBJECT_CHUNK, + compressor=DEFAULT_COMPRESSOR, ): self._int32max = np.iinfo(np.dtype("int32")).max self._dtype = dtype = np.dtype(dtype) - self._chunksize = chunksize self._object_chunking = object_chunking + self._compressor = compressor + self._chunksize = chunksize = self._get_chunksize(chunksize, compressor) self._shape = shape = self._get_shape(shape) self._max_shape = max_shape = self._get_max_shape(shape, max_shape) self._chunks = chunks = self._get_chunks( shape, max_shape, chunks, dtype, chunksize ) + def _get_chunksize(self, chunksize, compressor): + if isinstance(compressor, PngCodec): + return int(math.ceil(0.25 * chunksize)) + else: + return chunksize + def _get_shape(self, shape): assert shape is not None shape = (shape,) if isinstance(shape, int) else tuple(shape) diff --git a/hub/store/tests/test_shape_detector.py b/hub/store/tests/test_shape_detector.py index 75d3e1a14a..6b38e15956 100644 --- a/hub/store/tests/test_shape_detector.py +++ b/hub/store/tests/test_shape_detector.py @@ -8,6 +8,12 @@ def test_shape_detector(): assert s.chunks[1:] == (10, 10) +def test_shape_detector_2(): + s = ShapeDetector((10, 10, 10), 10, compressor="png") + assert str(s.dtype) == "float64" + assert s.chunks[1:] == (10, 10) + + def test_shape_detector_wrong_shape(): try: ShapeDetector((10, 10, 10), (10, 10, 20))