Skip to content

Commit

Permalink
Fix typehints in download modules.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672965431
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Sep 10, 2024
1 parent 1b8b37a commit a691e0a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 100 deletions.
27 changes: 7 additions & 20 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def replace(self, **kwargs: Any) -> DownloadConfig:
return dataclasses.replace(self, **kwargs)


class DownloadManager(object):
class DownloadManager:
"""Manages the download and extraction of files, as well as caching.
Downloaded files are cached under `download_dir`. The file name of downloaded
Expand Down Expand Up @@ -353,8 +353,9 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
"""
# Normalize the input
if isinstance(resource, str):
resource = resource_lib.Resource(url=resource)
url = resource.url
url = resource
else:
url = resource.url
assert url is not None, 'URL is undefined from resource.'

expected_url_info = self._url_infos.get(url)
Expand Down Expand Up @@ -500,7 +501,7 @@ def _rename_and_get_final_dl_path(
elif path == url_path:
if checksum_path:
# Checksums were registered: Rename -> checksums_path
resource_lib.rename_info_file(path, checksum_path, overwrite=True)
resource_lib.replace_info_file(path, checksum_path)
return path.replace(checksum_path)
else:
# Checksums not registered: -> do nothing
Expand All @@ -522,7 +523,7 @@ def _rename_and_get_final_dl_path(
@utils.memoize()
def _extract(self, resource: ExtractPath) -> promise.Promise[epath.Path]:
"""Extract a single archive, returns Promise->path to extraction result."""
if isinstance(resource, epath.PathLikeCls):
if not isinstance(resource, resource_lib.Resource):
resource = resource_lib.Resource(path=resource)
path = resource.path
extract_method = resource.extract_method
Expand Down Expand Up @@ -613,7 +614,7 @@ def iter_archive(
Returns:
Generator yielding tuple (path_within_archive, file_obj).
"""
if isinstance(resource, epath.PathLikeCls):
if not isinstance(resource, resource_lib.Resource):
resource = resource_lib.Resource(path=resource)
return extractor.iter_archive(resource.path, resource.extract_method)

Expand Down Expand Up @@ -763,20 +764,6 @@ def _validate_checksums(
raise NonMatchingChecksumError(msg)


def _read_url_info(url_path: epath.PathLike) -> checksums.UrlInfo:
"""Loads the `UrlInfo` from the `.INFO` file."""
file_info = resource_lib.read_info_file(url_path)
if 'url_info' not in file_info:
raise ValueError(
'Could not find `url_info` in {}. This likely indicates that '
'the files were downloaded with a previous version of TFDS (<=3.1.0). '
)
url_info = file_info['url_info']
url_info.setdefault('filename', None)
url_info['size'] = utils.Size(url_info['size'])
return checksums.UrlInfo(**url_info)


def _map_promise(map_fn, all_inputs):
"""Map the function into each element and resolve the promise."""
all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/download/download_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __setitem__(self, key, value):
return super().__setitem__(os.fspath(key), epath.Path(value))


class Artifact(object):
class Artifact:
# For testing only.

def __init__(self, name, url=None, content=None):
Expand Down
40 changes: 20 additions & 20 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
return _Downloader(*args, **kwargs)


def _read_url_info(url_path: epath.PathLike) -> checksums_lib.UrlInfo:
def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
"""Loads the `UrlInfo` from the `.INFO` file."""
file_info = resource_lib.read_info_file(url_path)
if 'url_info' not in file_info:
Expand Down Expand Up @@ -172,7 +172,7 @@ def _get_filename(response: Response) -> str:
return utils.basename_from_url(response.url)


class _Downloader(object):
class _Downloader:
"""Class providing async download API with checksum validation.
Do not instantiate this class directly. Instead, call `get_downloader()`.
Expand All @@ -192,9 +192,8 @@ def __init__(
"""Init _Downloader instance.
Args:
max_simultaneous_downloads: `int`, optional max number of simultaneous
downloads. If None then it defaults to
`self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
max_simultaneous_downloads: Optional max number of simultaneous downloads.
If None then it defaults to `self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
checksumer: `hashlib.HASH`. Defaults to `hashlib.sha256`.
"""
self._executor = concurrent.futures.ThreadPoolExecutor(
Expand Down Expand Up @@ -227,18 +226,18 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:

def download(
self, url: str, destination_path: str, verify: bool = True
) -> 'promise.Promise[concurrent.futures.Future[DownloadResult]]':
) -> promise.Promise[concurrent.futures.Future[DownloadResult]]:
"""Download url to given path.
Returns Promise -> sha256 of downloaded file.
Args:
url: address of resource to download.
destination_path: `str`, path to directory where to download the resource.
verify: whether to verify ssl certificates
url: Address of resource to download.
destination_path: Path to directory where to download the resource.
verify: Whether to verify ssl certificates
Returns:
Promise obj -> (`str`, int): (downloaded object checksum, size in bytes).
Promise obj -> Download result.
"""
destination_path = os.fspath(destination_path)
self._pbar_url.update_total(1)
Expand All @@ -250,19 +249,19 @@ def download(
def _sync_file_copy(
self,
filepath: str,
destination_path: str,
destination_path: epath.Path,
) -> DownloadResult:
"""Downloads the file through `tf.io.gfile` API."""
filename = os.path.basename(filepath)
out_path = os.path.join(destination_path, filename)
out_path = destination_path / filename
tf.io.gfile.copy(filepath, out_path)
url_info = checksums_lib.compute_url_info(
out_path, checksum_cls=self._checksumer_cls
)
self._pbar_dl_size.update_total(url_info.size)
self._pbar_dl_size.update(url_info.size)
self._pbar_url.update(1)
return DownloadResult(path=epath.Path(out_path), url_info=url_info)
return DownloadResult(path=out_path, url_info=url_info)

def _sync_download(
self, url: str, destination_path: str, verify: bool = True
Expand All @@ -275,16 +274,17 @@ def _sync_download(
https://requests.readthedocs.io/en/master/user/advanced/#proxies
Args:
url: url to download
destination_path: path where to write it
verify: whether to verify ssl certificates
url: Url to download.
destination_path: Path where to write it.
verify: Whether to verify ssl certificates.
Returns:
None
Download result.
Raises:
DownloadError: when download fails.
"""
destination_path = epath.Path(destination_path)
try:
# If url is on a filesystem that gfile understands, use copy. Otherwise,
# use requests (http) or urllib (ftp).
Expand All @@ -295,15 +295,15 @@ def _sync_download(

with _open_url(url, verify=verify) as (response, iter_content):
fname = _get_filename(response)
path = os.path.join(destination_path, fname)
path = destination_path / fname
size = 0

# Initialize the download size progress bar
size_mb = 0
unit_mb = units.MiB
total_size = int(response.headers.get('Content-length', 0)) // unit_mb
self._pbar_dl_size.update_total(total_size)
with tf.io.gfile.GFile(path, 'wb') as file_:
with path.open('wb') as file_:
checksum = self._checksumer_cls()
for block in iter_content:
size += len(block)
Expand All @@ -317,7 +317,7 @@ def _sync_download(
size_mb %= unit_mb
self._pbar_url.update(1)
return DownloadResult(
path=epath.Path(path),
path=path,
url_info=checksums_lib.UrlInfo(
checksum=checksum.hexdigest(),
size=utils.Size(size),
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/download/downloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tensorflow_datasets.core.download import util


class _FakeResponse(object):
class _FakeResponse:

def __init__(self, url, content, cookies=None, headers=None, status_code=200):
self.url = url
Expand Down
Loading

0 comments on commit a691e0a

Please sign in to comment.