Skip to content

Commit

Permalink
Refactor download_manager.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676819433
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Sep 20, 2024
1 parent 2c16950 commit 3b0dab2
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 222 deletions.
4 changes: 4 additions & 0 deletions tensorflow_datasets/core/download/checksums.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def _default_checksum_dirs() -> list[epath.Path]:
]


def sha256(str_: str) -> str:
return hashlib.sha256(str_.encode()).hexdigest()


@dataclasses.dataclass(eq=True)
class UrlInfo:
"""Small wrapper around the url metadata (checksum, size).
Expand Down
18 changes: 8 additions & 10 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import concurrent.futures
import dataclasses
import functools
import hashlib
import typing
from typing import Any
import uuid
Expand Down Expand Up @@ -316,8 +315,8 @@ def downloaded_size(self):
"""Returns the total size of downloaded files."""
return sum(url_info.size for url_info in self._recorded_url_infos.values())

def _get_dl_path(self, url: str, sha256: str) -> epath.Path:
return self._download_dir / resource_lib.get_dl_fname(url, sha256)
def _get_dl_path(self, url: str, checksum: str | None = None) -> epath.Path:
return self._download_dir / resource_lib.get_dl_fname(url, checksum)

@property
def register_checksums(self):
Expand Down Expand Up @@ -368,11 +367,9 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
manual_dir=self._manual_dir,
expected_url_info=expected_url_info,
)
url_path = self._get_dl_path(
url, sha256=hashlib.sha256(url.encode('utf-8')).hexdigest()
)
url_path = self._get_dl_path(url)
checksum_path = (
self._get_dl_path(url, sha256=expected_url_info.checksum)
self._get_dl_path(url, expected_url_info.checksum)
if expected_url_info
else None
)
Expand All @@ -392,10 +389,11 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
self._downloader.increase_tqdm(dl_result)
future = promise.Promise.resolve(dl_result)
else:
# Download in an empty tmp directory (to avoid name collisions)
# Download in a tmp directory next to url_path (to avoid name collisions)
# `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
dirname = f'{resource_lib.get_dl_dirname(url)}.tmp.{uuid.uuid4().hex}'
download_tmp_dir = self._download_dir / dirname
download_tmp_dir = (
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
)
download_tmp_dir.mkdir()
logging.info(f'Downloading {url} into {download_tmp_dir}...')
future = self._downloader.download(
Expand Down
Loading

0 comments on commit 3b0dab2

Please sign in to comment.