diff --git a/deeplake/api/tests/test_api.py b/deeplake/api/tests/test_api.py
index ba0795767b..07ab15480b 100644
--- a/deeplake/api/tests/test_api.py
+++ b/deeplake/api/tests/test_api.py
@@ -1159,6 +1159,7 @@ def test_compressions_list():
         "png",
         "ppm",
         "sgi",
+        "stl",
         "tga",
         "tiff",
         "wav",
diff --git a/deeplake/api/tests/test_mesh.py b/deeplake/api/tests/test_mesh.py
index 02b410af8d..2337693a8a 100644
--- a/deeplake/api/tests/test_mesh.py
+++ b/deeplake/api/tests/test_mesh.py
@@ -1,7 +1,12 @@
 import pytest
 
 import deeplake
-from deeplake.util.exceptions import DynamicTensorNumpyError
+import numpy as np
+from deeplake.util.exceptions import (
+    DynamicTensorNumpyError,
+    MeshTensorMetaMissingRequiredValue,
+    UnsupportedCompressionError,
+)
 
 
 def test_mesh(local_ds, mesh_paths):
@@ -31,3 +36,28 @@ def test_mesh(local_ds, mesh_paths):
 
     tensor_data = tensor.data()
     assert len(tensor_data) == 4
+
+
+def test_stl_mesh(local_ds, stl_mesh_paths):
+    tensor = local_ds.create_tensor("stl_mesh", htype="mesh", sample_compression="stl")
+
+    with pytest.raises(UnsupportedCompressionError):
+        local_ds.create_tensor("unsupported", htype="mesh", sample_compression=None)
+
+    with pytest.raises(MeshTensorMetaMissingRequiredValue):
+        local_ds.create_tensor("unsupported", htype="mesh")
+
+    for path in stl_mesh_paths.values():
+        sample = deeplake.read(path)
+        tensor.append(sample)
+        tensor.append(deeplake.read(path))
+
+    tensor_numpy = tensor.numpy()
+    assert tensor_numpy.shape == (4, 12, 3, 3)
+    assert np.all(tensor_numpy[0] == tensor_numpy[1])
+    assert np.all(tensor_numpy[1] == tensor_numpy[2])
+    assert np.all(tensor_numpy[2] == tensor_numpy[3])
+
+    tensor_data = tensor.data()
+    tensor_0_data = tensor[0].data()
+    assert np.all(tensor_data["vertices"][0] == tensor_0_data["vertices"])
diff --git a/deeplake/compression.py b/deeplake/compression.py
index 8b9139b2ec..8eba9f3de0 100644
--- a/deeplake/compression.py
+++ b/deeplake/compression.py
@@ -73,7 +73,7 @@
 AUDIO_COMPRESSIONS = ["mp3", "flac", "wav"]
 NIFTI_COMPRESSIONS = ["nii", "nii.gz"]
 POINT_CLOUD_COMPRESSIONS = ["las"]
-MESH_COMPRESSIONS = ["ply"]
+MESH_COMPRESSIONS = ["ply", "stl"]
 
 READONLY_COMPRESSIONS = [
     "mpo",
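The `(4, 12, 3, 3)` assertion in `test_stl_mesh` above follows directly from numpy-stl's facet layout: the box fixtures added later in this diff triangulate a cube into 12 facets of 3 vertices with 3 coordinates each, and the test appends each of the two fixture files twice. A minimal sketch of that layout, assuming numpy-stl is installed:

import numpy as np
from stl import mesh

# Load the ASCII fixture added later in this diff and inspect the facet array.
box = mesh.Mesh.from_file("deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl")
print(box.vectors.shape)  # (12, 3, 3): facets x vertices x xyz
print(box.vectors.dtype)  # float32, which is what Deep Lake stores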
diff --git a/deeplake/core/compression.py b/deeplake/core/compression.py
index bc3296cba6..c793a312d9 100644
--- a/deeplake/core/compression.py
+++ b/deeplake/core/compression.py
@@ -293,7 +293,7 @@ def decompress_array(
         return _decompress_audio(buffer)
     elif compr_type == VIDEO_COMPRESSION:
         return _decompress_video(buffer, start_idx, end_idx, step, reverse)  # type: ignore
-    elif compr_type in [POINT_CLOUD_COMPRESSION, MESH_COMPRESSION]:
+    elif compr_type in [POINT_CLOUD_COMPRESSION] or compression == "ply":
         return _decompress_3d_data(buffer)
 
     if compression == "apng":
@@ -304,6 +304,8 @@ def decompress_array(
         return _decompress_nifti(buffer)
     if compression == "nii.gz":
         return _decompress_nifti(buffer, gz=True)
+    if compression == "stl":
+        return _decompress_stl(buffer)
     if compression is None and isinstance(buffer, memoryview) and shape is not None:
         assert buffer is not None
         assert shape is not None
@@ -464,6 +466,8 @@ def verify_compressed_file(
             return _read_nifti_shape_and_dtype(file, gz=compression == "nii.gz")
         elif compression in ("las", "ply"):
             return _read_3d_data_shape_and_dtype(file)
+        elif compression == "stl":
+            return _read_stl_shape_and_dtype(file)
         else:
             return _fast_decompress(file)
     except Exception as e:
@@ -490,6 +494,7 @@ def get_compression(header=None, path=None):
             ".ply",
             ".nii",
             ".nii.gz",
+            ".stl",
         ]
         path = str(path).lower().partition("?")[0].partition("#")[0].partition(";")[0]
         for fmt in file_formats:
@@ -519,6 +524,10 @@ def get_compression(header=None, path=None):
             return "dcm"
         if header[0:4] == b"\x6e\x2b\x31\x00":
             return "nii"
+        if any(
+            header[: len(x)] == x for x in [b"\x73\x6F\x6C\x69", b"numpy-stl", b"solid"]
+        ):
+            return "stl"
         if not Image.OPEN:
             Image.init()
         for fmt in Image.OPEN:
@@ -711,6 +720,11 @@ def read_meta_from_compressed_file(
                 shape, typestr = _read_3d_data_shape_and_dtype(file)
             except Exception as e:
                 raise CorruptedSampleError(compression, path) from e
+        elif compression == "stl":
+            try:
+                shape, typestr = _read_stl_shape_and_dtype(file)
+            except Exception as e:
+                raise CorruptedSampleError(compression, path) from e
         else:
             img = Image.open(f) if isfile else Image.open(BytesIO(f))  # type: ignore
             shape, typestr = Image._conv_type_shape(img)
@@ -1183,6 +1197,15 @@ def _open_3d_data(file):
     return point_cloud
 
 
+def _open_mesh_data(file: Union[bytes, memoryview, str]):
+    from stl import mesh
+
+    if isinstance(file, str):
+        return mesh.Mesh.from_file(file)
+
+    return mesh.Mesh.from_file("", fh=BytesIO(file))
+
+
 def _decompress_3d_data(file: Union[bytes, memoryview, str]):
     point_cloud = _open_3d_data(file)
     return point_cloud.decompressed_3d_data
@@ -1193,11 +1216,34 @@ def _read_3d_data_shape_and_dtype(file: Union[bytes, BinaryIO]):
     return point_cloud.shape, point_cloud.dtype
 
 
+def _read_stl_shape_and_dtype(file):
+    mesh_data = _open_mesh_data(file)
+    return mesh_data.vectors.shape, mesh_data.vectors.dtype
+
+
+def _decompress_stl(file: Union[bytes, str]):
+    mesh_data = _open_mesh_data(file)
+    return mesh_data.vectors
+
+
 def _read_3d_data_meta(file: Union[bytes, memoryview, str]):
     point_cloud = _open_3d_data(file)
    return point_cloud.meta_data
 
 
+def _read_stl_data_meta(file: Union[bytes, memoryview, str]):
+    mesh_data = _open_mesh_data(file)
+    return {
+        "name": mesh_data.name,
+        "min_": mesh_data.min_,
+        "max_": mesh_data.max_,
+        "speedups": mesh_data.speedups,
+        "centroids": mesh_data.centroids,
+        "normals": mesh_data.normals,
+        "extension": "stl",
+    }
+
+
 def _open_nifti(file: Union[bytes, memoryview, str], gz: bool = False):
     try:
         import nibabel as nib  # type: ignore
diff --git a/deeplake/core/meta/tensor_meta.py b/deeplake/core/meta/tensor_meta.py
index 259fce06b7..0e1dafe80a 100644
--- a/deeplake/core/meta/tensor_meta.py
+++ b/deeplake/core/meta/tensor_meta.py
@@ -9,6 +9,7 @@
     TensorMetaInvalidHtypeOverwriteValue,
     TensorMetaInvalidHtypeOverwriteKey,
     TensorMetaMissingRequiredValue,
+    MeshTensorMetaMissingRequiredValue,
     TensorMetaMutuallyExclusiveKeysError,
     UnsupportedCompressionError,
     TensorInvalidSampleShapeError,
@@ -328,7 +329,16 @@ def _validate_htype_overwrites(htype: str, htype_overwrite: dict):
             raise TensorMetaMissingRequiredValue(
                 actual_htype, ["chunk_compression", "sample_compression"]  # type: ignore
             )
-        if htype in ("audio", "video", "point_cloud", "mesh", "nifti"):
+        if htype == "mesh":
+            supported_compressions = HTYPE_SUPPORTED_COMPRESSIONS.get(htype)
+            if sc == UNSPECIFIED:
+                raise MeshTensorMetaMissingRequiredValue(
+                    actual_htype, "sample_compression", compr_list=supported_compressions  # type: ignore
+                )
+            if sc not in supported_compressions:  # type: ignore
+                raise UnsupportedCompressionError(sc, htype=htype)
+
+        elif htype in ("audio", "video", "point_cloud", "nifti"):
             if cc not in (UNSPECIFIED, None):
                 raise UnsupportedCompressionError("Chunk compression", htype=htype)
             elif sc == UNSPECIFIED:
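A sketch of the validation added to `_validate_htype_overwrites` above: mesh tensors must now name a supported `sample_compression` ("ply" or "stl") when they are created. This assumes the in-memory `mem://` path scheme for a throwaway dataset:

import deeplake
from deeplake.util.exceptions import (
    MeshTensorMetaMissingRequiredValue,
    UnsupportedCompressionError,
)

ds = deeplake.empty("mem://demo")
ds.create_tensor("ok", htype="mesh", sample_compression="stl")

try:
    ds.create_tensor("missing", htype="mesh")  # no sample_compression at all
except MeshTensorMetaMissingRequiredValue as e:
    print(e)  # message lists the available compressors

try:
    ds.create_tensor("wrong", htype="mesh", sample_compression=None)
except UnsupportedCompressionError as e:
    print(e)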
diff --git a/deeplake/core/sample.py b/deeplake/core/sample.py
index 7db51469d9..a6ed787887 100644
--- a/deeplake/core/sample.py
+++ b/deeplake/core/sample.py
@@ -9,6 +9,7 @@
     _read_metadata_from_vstream,
     _read_audio_meta,
     _read_3d_data_meta,
+    _read_stl_data_meta,
     _open_nifti,
     HEADER_MAX_BYTES,
 )
@@ -275,6 +276,13 @@ def _get_point_cloud_meta(self) -> dict:
             info = _read_3d_data_meta(self.buffer)
         return info
 
+    def _get_stl_meta(self) -> dict:
+        if self.path and get_path_type(self.path) == "local":
+            info = _read_stl_data_meta(self.path)
+        else:
+            info = _read_stl_data_meta(self.buffer)
+        return info
+
     @property
     def is_lazy(self) -> bool:
         return self._array is None
@@ -552,6 +560,8 @@ def meta(self) -> dict:
         compression_type = get_compression_type(compression)
         if compression == "dcm":
             meta.update(self._get_dicom_meta())
+        elif compression == "stl":
+            meta.update(self._get_stl_meta())
         elif compression_type == NIFTI_COMPRESSION:
             meta.update(self._get_nifti_meta())
         elif compression_type == IMAGE_COMPRESSION:
diff --git a/deeplake/enterprise/dataloader.py b/deeplake/enterprise/dataloader.py
index 289ea401af..94facb8a71 100644
--- a/deeplake/enterprise/dataloader.py
+++ b/deeplake/enterprise/dataloader.py
@@ -575,18 +575,19 @@ def tensorflow(
         Args:
             num_workers (int): Number of workers to use for transforming and processing the data. Defaults to 0.
             collate_fn (Callable, Optional): merges a list of samples to form a mini-batch of Tensor(s).
-            tensors (List[str], Optional): List of tensors to load. If None, all tensors are loaded. Defaults to ``None``.
+            tensors (List, Optional): List of tensors to load. If ``None``, all tensors are loaded. Defaults to ``None``.
+                For datasets with many tensors, it's extremely important to stream only the data that is needed for training the model, in order to avoid bottlenecks associated with streaming unused data.
+                For example, if you have a dataset with ``image``, ``label``, and ``metadata`` tensors and ``tensors=["image", "label"]``, the Data Loader will only stream the ``image`` and ``label`` tensors.
             num_threads (int, Optional): Number of threads to use for fetching and decompressing the data. If ``None``, the number of threads is automatically determined. Defaults to ``None``.
             prefetch_factor (int): Number of batches to transform and collate in advance per worker. Defaults to 2.
-            return_index (bool): Used to idnetify where loader needs to retur sample index or not. Defaults to ``True``.
+            return_index (bool): If ``True``, the returned dataloader will have a key "index" that contains the index of the sample(s) in the original dataset. Defaults to ``True``.
             persistent_workers (bool): If ``True``, the data loader will not shutdown the worker processes after a dataset has been consumed once. Defaults to ``False``.
-            decode_method (Dict[str, str], Optional): A dictionary of decode methods for each tensor. Defaults to ``None``.
-
+            decode_method (Dict[str, str], Optional): The method for decoding the Deep Lake tensor data, the result of which is passed to the transform. Decoding occurs outside of the transform so that it can be performed in parallel and as rapidly as possible as per Deep Lake optimizations.
 
                 - Supported decode methods are:
 
-                    :'numpy': Default behaviour. Returns samples as numpy arrays.
-                    :'tobytes': Returns raw bytes of the samples.
+                    :'numpy': Default behaviour. Returns samples as numpy arrays, the same as ``ds.tensor[i].numpy()``.
+                    :'tobytes': Returns raw bytes of the samples, the same as ``ds.tensor[i].tobytes()``.
+                    :'data': Returns a dictionary with keys and values depending on the htype, the same as ``ds.tensor[i].data()``.
                     :'pil': Returns samples as PIL images. Especially useful when transformation use torchvision transforms, that
                         require PIL images as input. Only supported for tensors with ``sample_compression='jpeg'`` or ``'png'``.
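A short usage sketch for the docstring above; it requires the enterprise dataloader, `ds` is an already-loaded dataset, and the tensor names are hypothetical:

dl = (
    ds.dataloader()
    .batch(32)
    .tensorflow(
        tensors=["image", "label"],  # stream only what training consumes
        decode_method={"image": "numpy", "label": "numpy"},
        return_index=False,
    )
)
for batch in dl:
    ...  # batch contains only the requested tensors, decoded as numpy arrays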
diff --git a/deeplake/requirements/common.txt b/deeplake/requirements/common.txt
index 5711ab2b68..ceb4fd23b1 100644
--- a/deeplake/requirements/common.txt
+++ b/deeplake/requirements/common.txt
@@ -11,7 +11,7 @@ pathos
 humbug>=0.3.1
 tqdm
 lz4
-av>=8.1.0; python_version >= '3.7' or sys_platform != 'win32'
+av>=8.1.0,<=12.3.0; python_version >= '3.7' or sys_platform != 'win32'
 pydicom
 IPython
 flask
@@ -24,3 +24,4 @@ azure-cli
 azure-identity
 azure-storage-blob
 pydantic
+numpy-stl
diff --git a/deeplake/requirements/plugins.txt b/deeplake/requirements/plugins.txt
index 68757a698a..ddf7cb5e51 100644
--- a/deeplake/requirements/plugins.txt
+++ b/deeplake/requirements/plugins.txt
@@ -10,4 +10,4 @@ mmdet==2.28.1; platform_system == "Linux" and python_version >= "3.7"
 mmsegmentation==0.30.0; platform_system == "Linux" and python_version >= "3.7"
 mmengine
 pandas
-av
\ No newline at end of file
+av==12.3.0
\ No newline at end of file
diff --git a/deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl b/deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl
new file mode 100644
index 0000000000..9d0f10c7a6
--- /dev/null
+++ b/deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl
@@ -0,0 +1,86 @@
+solid b'MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH'
+facet normal 0.0 -400.0 0.0
+  outer loop
+    vertex 0.0 -20.0 0.0
+    vertex 20.0 -20.0 0.0
+    vertex 20.0 -20.0 20.0
+  endloop
+endfacet
+facet normal 0.0 -400.0 0.0
+  outer loop
+    vertex 0.0 -20.0 0.0
+    vertex 20.0 -20.0 20.0
+    vertex 0.0 -20.0 20.0
+  endloop
+endfacet
+facet normal -400.0 0.0 0.0
+  outer loop
+    vertex 0.0 0.0 0.0
+    vertex 0.0 -20.0 0.0
+    vertex 0.0 -20.0 20.0
+  endloop
+endfacet
+facet normal -400.0 0.0 0.0
+  outer loop
+    vertex 0.0 0.0 0.0
+    vertex 0.0 -20.0 20.0
+    vertex 0.0 0.0 20.0
+  endloop
+endfacet
+facet normal 0.0 400.0 0.0
+  outer loop
+    vertex 20.0 0.0 0.0
+    vertex 0.0 0.0 0.0
+    vertex 0.0 0.0 20.0
+  endloop
+endfacet
+facet normal 0.0 400.0 -0.0
+  outer loop
+    vertex 20.0 0.0 0.0
+    vertex 0.0 0.0 20.0
+    vertex 20.0 0.0 20.0
+  endloop
+endfacet
+facet normal 400.0 0.0 0.0
+  outer loop
+    vertex 20.0 -20.0 0.0
+    vertex 20.0 0.0 0.0
+    vertex 20.0 0.0 20.0
+  endloop
+endfacet
+facet normal 400.0 0.0 0.0
+  outer loop
+    vertex 20.0 -20.0 0.0
+    vertex 20.0 0.0 20.0
+    vertex 20.0 -20.0 20.0
+  endloop
+endfacet
+facet normal 0.0 0.0 -400.0
+  outer loop
+    vertex 0.0 0.0 0.0
+    vertex 20.0 -20.0 0.0
+    vertex 0.0 -20.0 0.0
+  endloop
+endfacet
+facet normal 0.0 0.0 -400.0
+  outer loop
+    vertex 0.0 0.0 0.0
+    vertex 20.0 0.0 0.0
+    vertex 20.0 -20.0 0.0
+  endloop
+endfacet
+facet normal 0.0 0.0 400.0
+  outer loop
+    vertex 20.0 -20.0 20.0
+    vertex 0.0 0.0 20.0
+    vertex 0.0 -20.0 20.0
+  endloop
+endfacet
+facet normal 0.0 0.0 400.0
+  outer loop
+    vertex 20.0 0.0 20.0
+    vertex 0.0 0.0 20.0
+    vertex 20.0 -20.0 20.0
+  endloop
+endfacet
+endsolid b'MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH-MESH'
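The ASCII fixture above triangulates a 20 x 20 x 20 box into the 12 facets that test_stl_mesh expects. As a quick cross-check of the metadata captured by `_read_stl_data_meta`, numpy-stl's centroids are derived from the facet vertices (a sketch; assumes the per-facet vertex mean definition):

import numpy as np
from stl import mesh

box = mesh.Mesh.from_file("deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl")
# Each centroid is expected to be the mean of its facet's three vertices.
print(np.allclose(box.centroids, box.vectors.mean(axis=1)))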
diff --git a/deeplake/tests/dummy_data/mesh/box_freecad_binary.stl b/deeplake/tests/dummy_data/mesh/box_freecad_binary.stl
new file mode 100644
index 0000000000..694f7a3a35
Binary files /dev/null and b/deeplake/tests/dummy_data/mesh/box_freecad_binary.stl differ
diff --git a/deeplake/tests/path_fixtures.py b/deeplake/tests/path_fixtures.py
index 83e7b4850f..df44af4ea1 100644
--- a/deeplake/tests/path_fixtures.py
+++ b/deeplake/tests/path_fixtures.py
@@ -719,6 +719,19 @@ def mesh_paths():
     return paths
 
 
+@pytest.fixture
+def stl_mesh_paths():
+    paths = {
+        "ascii": "box_freecad_ascii.stl",
+        "bin": "box_freecad_binary.stl",
+    }
+
+    parent = get_dummy_data_path("mesh")
+    for k in paths:
+        paths[k] = os.path.join(parent, paths[k])
+    return paths
+
+
 @pytest.fixture
 def vstream_path(request):
     """Used with parametrize to use all video stream test datasets."""
diff --git a/deeplake/util/exceptions.py b/deeplake/util/exceptions.py
index 6fd72a1b10..7486d59586 100644
--- a/deeplake/util/exceptions.py
+++ b/deeplake/util/exceptions.py
@@ -441,6 +441,19 @@ def __init__(self, key: str, value: Any, explanation: str = ""):
         )
 
 
+class MeshTensorMetaMissingRequiredValue(MetaError):
+    def __init__(self, htype: str, key: Union[str, List[str]], compr_list: List[str]):
+        extra = ""
+        if key == "sample_compression":
+            extra = f"Available compressors: {compr_list}"
+
+        if isinstance(key, list):
+            message = f"Htype '{htype}' requires you to specify one of {key} inside the `create_tensor` method call. {extra}"
+        else:
+            message = f"Htype '{htype}' requires you to specify '{key}' inside the `create_tensor` method call. {extra}"
+        super().__init__(message)
+
+
 class TensorMetaMissingRequiredValue(MetaError):
     def __init__(self, htype: str, key: Union[str, List[str]]):
         extra = ""
diff --git a/deeplake/util/object_3d/mesh.py b/deeplake/util/object_3d/mesh.py
index 8cf1212df2..c0eb198858 100644
--- a/deeplake/util/object_3d/mesh.py
+++ b/deeplake/util/object_3d/mesh.py
@@ -1,6 +1,7 @@
 import os
 import re
 import sys
+import numpy as np
 
 from deeplake.util import exceptions
 from deeplake.util.exceptions import DynamicTensorNumpyError  # type: ignore
@@ -109,15 +110,48 @@ def dtype(self):
 
 
 def parse_mesh_to_dict(full_arr, sample_info):
-    # we assume that the format of files that we append is the same
-    fmt = sample_info[0]["fmt"]
-    ext = sample_info[0]["extension"]
-    parser = mesh_parser.get_mesh_parser(ext)(fmt, full_arr, sample_info)
+    if not sample_info:
+        return {"value": full_arr, "sample_info": {}}
+
+    sample_info = (
+        sample_info
+        if isinstance(sample_info, (list, np.ndarray))
+        else [sample_info]
+    )
+    first_info = sample_info[0]
+
+    if first_info["extension"] == "stl":
+        centroids = (
+            first_info.pop("centroids")
+            if len(sample_info) == 1
+            else [info.pop("centroids") for info in sample_info]
+        )
+        normals = (
+            first_info.pop("normals")
+            if len(sample_info) == 1
+            else [info.pop("normals") for info in sample_info]
+        )
+        return {
+            "vertices": full_arr,
+            "centroids": centroids,
+            "normals": normals,
+            "sample_info": first_info if len(sample_info) == 1 else sample_info,
+        }
+
+    parser = mesh_parser.get_mesh_parser(first_info["extension"])(
+        first_info["fmt"], full_arr, sample_info
+    )
     return parser.data
 
 
 def get_mesh_vertices(tensor_name, index, ret, sample_info, aslist):
     # we assume that the format of files that we append is the same
+    if not sample_info:
+        return ret
+    if isinstance(sample_info, dict):
+        sample_info = [sample_info]
+    if sample_info[0]["extension"] == "stl":
+        return ret
     fmt = sample_info[0]["fmt"]
     ext = sample_info[0]["extension"]
     parser = mesh_parser.get_mesh_parser(ext)(
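How the new stl branch of `parse_mesh_to_dict` surfaces through the tensor API (a sketch; the tensor name mirrors "stl_mesh" from test_mesh.py above):

data = ds.stl_mesh[0].data()
data["vertices"]     # (num_facets, 3, 3) float32 array, identical to .numpy()
data["centroids"]    # per-facet centroids captured by _read_stl_data_meta
data["normals"]      # per-facet normals read from the file
data["sample_info"]  # remaining metadata: name, min_, max_, speedups, ...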
diff --git a/docs/source/Compressions.rst b/docs/source/Compressions.rst
index 6b9ce7cf94..6c4dd3290a 100644
--- a/docs/source/Compressions.rst
+++ b/docs/source/Compressions.rst
@@ -24,7 +24,7 @@ are given below.
 +----------------+----------------+----------------------------------------+
 | Point Cloud    | point_cloud    | ``las``                                |
 +----------------+----------------+----------------------------------------+
-| Mesh           | mesh           | ``ply``                                |
+| Mesh           | mesh           | ``ply``, ``stl``                       |
 +----------------+----------------+----------------------------------------+
 | Other          | bbox, text,    | ``lz4``                                |
 |                | list, json,    |                                        |
diff --git a/docs/source/Htypes.rst b/docs/source/Htypes.rst
index faa8f22367..334992e261 100644
--- a/docs/source/Htypes.rst
+++ b/docs/source/Htypes.rst
@@ -134,6 +134,14 @@ Extending with multiple videos
 
 >>> ds.videos.extend([deeplake.read(f"videos/00{i}.mp4") for i in range(10)])
 
+Reading Video samples
+
+>>> ds.videos[0].numpy() # returns the video frames for the 1st video as a numpy array
+
+>>> ds.videos[0, 9:20].numpy() # returns the 10th-20th frames of the first video as a numpy array
+
+>>> ds.videos[0].data() # returns a dictionary for the first video with the frames decoded as a numpy array, as well as other video metadata.
+
 .. _audio-htype:
 
 Audio Htype
@@ -797,7 +805,7 @@ A point cloud tensor can be created using
 
 :bluebold:`Examples`
 
-Appending point clouds with numpy arrays
+Appending point clouds with numpy arrays. This is only available when ``sample_compression`` is set to ``None``.
 
 >>> import numpy as np
 >>> point_cloud1 = np.random.randint(0, 10, (5, 3))
 >>> ds.point_clouds.append(point_cloud1)
 >>> ds.point_clouds.shape
 >>> (2, None, 3)
 
-Or we can use :meth:`deeplake.read` method to add samples
+Or we can use the :meth:`deeplake.read` method to add data as files
+
+>>> ds.point_cloud.append(deeplake.read("example.las"))
+>>> ds.point_cloud[0].shape
+>>> (100, 3)
+
+Reading data from a point_cloud
+
+>>> ds.point_cloud[0].numpy() # returns the point cloud points as a numpy array of shape (num_points, 3)
+
+>>> ds.point_cloud[0].data() # returns a dictionary with file-dependent keys representing a variety of different data, such as X, Y, Z, intensity, etc.
+
->>> import deeplake as dp
->>> sample = dp.read("example.las") # point cloud with 100 points
->>> ds.point_cloud.append(sample)
->>> ds.point_cloud.shape
->>> (1, 100, 3)
 
 .. _mesh-htype:
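For instance, a LAS-backed sample could be reassembled into an (N, 3) array from those keys (a sketch; the key names depend on the source file):

>>> sample = ds.point_cloud[0].data()
>>> np.stack([sample["X"], sample["Y"], sample["Z"]], axis=-1).shape
(100, 3)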
@@ -831,13 +845,14 @@ Mesh Htype
 
 A mesh tensor can be created using
 
->>> ds.create_tensor("mesh", htype="mesh", sample_compression="ply")
+>>> ds.create_tensor("mesh_ply", htype="mesh", sample_compression="ply")
+>>> ds.create_tensor("mesh_stl", htype="mesh", sample_compression="stl")
 
 - Optional args:
     - :ref:`sample_compression <sample_compression>`
 - Supported compressions:
 
->>> ["ply"]
+>>> ["ply", "stl"]
 
 :blue:`Appending meshes`
 ------------------------
@@ -846,12 +861,22 @@
 
 Appending a ply file containing a mesh data to tensor
 
->>> import deeplake as dp
->>> sample = dp.read("example.ply") # mesh with 100 points and 200 faces
->>> ds.mesh.append(sample)
+>>> ds.mesh.append(deeplake.read("example.ply"))
+>>> ds.mesh[0].shape
+>>> (100, 3)
+
+Appending an stl file containing mesh data to the tensor
+
+>>> ds.mesh.append(deeplake.read("example.stl"))
 >>> ds.mesh.shape
->>> (1, 100, 3)
+>>> (100, 3, 3)
+
+Reading data from a mesh
+
+>>> ds.mesh[0].numpy() # returns the vertices of the mesh as a numpy array
+
+>>> ds.mesh[0].data() # returns a dictionary with the following keys: vertices (same as .numpy()), centroids, normals, and sample_info (other metadata)
 
 .. _embedding-htype:
diff --git a/setup.py b/setup.py
index 0a86b65097..1fecf7c0b9 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
         "google-auth-oauthlib",
     ],
     "point_cloud": ["laspy"],
+    "mesh": ["laspy", "numpy-stl"],
 }
 
 
@@ -70,7 +71,7 @@ def libdeeplake_available():
     extras_require["all"] = [req_map[r] for r in all_extras]
 
     if libdeeplake_available():
-        libdeeplake = "libdeeplake==0.0.142"
+        libdeeplake = "libdeeplake==0.0.144"
         extras_require["enterprise"] = [libdeeplake, "pyjwt"]
         extras_require["all"].append(libdeeplake)
         install_requires.append(libdeeplake)
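Taken together, the changes in this diff enable the following end-to-end flow (a sketch; the mem:// path is a throwaway in-memory dataset and the fixture path is the one added above):

import deeplake

ds = deeplake.empty("mem://mesh_demo")
ds.create_tensor("mesh", htype="mesh", sample_compression="stl")
ds.mesh.append(deeplake.read("deeplake/tests/dummy_data/mesh/box_freecad_ascii.stl"))

print(ds.mesh[0].numpy().shape)  # (12, 3, 3)
print(list(ds.mesh[0].data()))   # ['vertices', 'centroids', 'normals', 'sample_info']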