Skip to content

Commit

Permalink
Refactor download_manager.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679662316
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Sep 27, 2024
1 parent da34559 commit c37ca97
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 164 deletions.
230 changes: 114 additions & 116 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,15 +355,11 @@ def _get_manually_downloaded_path(
# processed once, even if passed twice to download_manager.
@utils.build_synchronize_decorator()
@utils.memoize()
def _download(self, resource: Url) -> promise.Promise[epath.Path]:
def _download_or_get_cache(
self, resource: Url
) -> promise.Promise[epath.Path]:
"""Downloads resource or gets downloaded cache.
This function:
1. Reuse cache (`_get_cached_path`) or download the file
2. Register or validate checksums (`_register_or_validate_checksums`)
3. Rename download to final path (`_rename_and_get_final_dl_path`)
Args:
resource: The URL to download.
Expand All @@ -378,76 +374,79 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:

expected_url_info = self._url_infos.get(url)

# 3 possible destinations for the path:
# * In `manual_dir` (manually downloaded data)
# * In `downloads/url_path` (checksum unknown)
# * In `downloads/checksum_path` (checksum registered)
manually_downloaded_path = self._get_manually_downloaded_path(
expected_url_info=expected_url_info
)
url_path = self._get_dl_path(resource)
checksum_path = (
self._get_dl_path(resource, expected_url_info.checksum)
if expected_url_info
else None
)

# Get the cached path and url_info (if they exists)
dl_result = downloader.get_cached_path(
manually_downloaded_path=manually_downloaded_path,
checksum_path=checksum_path,
url_path=url_path,
expected_url_info=expected_url_info,
)
if dl_result and not self._force_download: # Download was cached
logging.info(
f'Skipping download of {url}: File cached in {dl_result.path}'
# User has manually downloaded the file.
if manually_downloaded_path := self._get_manually_downloaded_path(
expected_url_info
):
computed_url_info = checksums.compute_url_info(manually_downloaded_path)
self._register_or_validate_checksums(
resource=resource,
path=manually_downloaded_path,
computed_url_info=computed_url_info,
)
# Still update the progression bar to indicate the file was downloaded
self._downloader.increase_tqdm(dl_result.url_info)
future = promise.Promise.resolve(dl_result)
else:
# Download in a tmp directory next to url_path (to avoid name collisions)
# `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
download_tmp_dir = (
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
self._log_skip_download(
url=url, url_info=computed_url_info, path=manually_downloaded_path
)
download_tmp_dir.mkdir()
logging.info(f'Downloading {url} into {download_tmp_dir}...')
future = self._downloader.download(
url, download_tmp_dir, verify=self._verify_ssl
return promise.Promise.resolve(manually_downloaded_path)

# Force download
elif self._force_download:
return self._download(resource)

# Download has been cached (checksum known)
elif expected_url_info and resource_lib.Resource.exists_locally(
checksum_path := self._get_dl_path(resource, expected_url_info.checksum)
):
self._register_or_validate_checksums(
resource=resource,
path=checksum_path,
computed_url_info=expected_url_info,
)
self._log_skip_download(
url=url, url_info=expected_url_info, path=checksum_path
)
return promise.Promise.resolve(checksum_path)

# Download has been cached (checksum unknown)
elif resource_lib.Resource.exists_locally(
url_path := self._get_dl_path(resource)
):
computed_url_info = downloader.read_url_info(url_path)
if expected_url_info and expected_url_info != computed_url_info:
# If checksums are registered but do not match, trigger a new
# download (e.g. previous file corrupted, checksums updated)
return self._download(resource)
if checksum_path := self._register_or_validate_checksums(
resource=resource, path=url_path, computed_url_info=computed_url_info
):
# Checksums were registered: Rename -> checksum_path
resource_lib.replace_info_file(url_path, checksum_path)
path = url_path.replace(checksum_path)
else:
# Checksums not registered: -> do nothing
path = url_path
self._log_skip_download(url=url, url_info=computed_url_info, path=path)
return promise.Promise.resolve(path)

# Post-process the result
return future.then(
lambda dl_result: self._register_or_validate_checksums( # pylint: disable=g-long-lambda
resource=resource,
path=dl_result.path,
computed_url_info=dl_result.url_info,
expected_url_info=expected_url_info,
checksum_path=checksum_path,
url_path=url_path,
)
)
# Cache not found
else:
return self._download(resource)

def _log_skip_download(
    self, url: str, url_info: checksums.UrlInfo, path: epath.Path
) -> None:
  """Logs that `url` was served from cache and updates the progress bar."""
  message = f'Skipping download of {url}: File cached in {path}'
  logging.info(message)
  # The file still counts as "downloaded" for reporting purposes, so advance
  # the progress bar even though no network transfer happened.
  self._downloader.increase_tqdm(url_info)

def _register_or_validate_checksums(
self,
resource: resource_lib.Resource,
path: epath.Path,
expected_url_info: checksums.UrlInfo | None,
computed_url_info: checksums.UrlInfo,
checksum_path: epath.Path | None,
url_path: epath.Path,
) -> epath.Path:
"""Validates/records checksums and renames final downloaded path."""
# `path` can be:
# * Manually downloaded
# * (cached) checksum_path
# * (cached) url_path
# * `tmp_dir/file` (downloaded path)

) -> epath.Path | None:
"""Validates/records checksums and returns checksum path if registered."""
url: str = resource.url # pytype: disable=annotation-type-mismatch
# Used both in `.downloaded_size` and `_record_url_infos()`
self._recorded_url_infos[url] = computed_url_info

if self._register_checksums:
Expand All @@ -457,12 +456,9 @@ def _register_or_validate_checksums(
# * `register_checksums_path` was validated in `__init__` so this
# shouldn't fail.
self._record_url_infos()

# Checksum path should now match the new registered checksum (even if
# checksums were previously registered)
expected_url_info = computed_url_info
checksum_path = self._get_dl_path(resource, computed_url_info.checksum)
return self._get_dl_path(resource, computed_url_info.checksum)
else:
expected_url_info = self._url_infos.get(url)
# Eventually validate checksums
# Note:
# * If path is cached at `url_path` but cached
Expand All @@ -478,15 +474,8 @@ def _register_or_validate_checksums(
computed_url_info=computed_url_info,
path=path,
)

return self._rename_and_get_final_dl_path(
url=url,
path=path,
expected_url_info=expected_url_info,
computed_url_info=computed_url_info,
checksum_path=checksum_path,
url_path=url_path,
)
if expected_url_info:
return self._get_dl_path(resource, expected_url_info.checksum)

def _validate_checksums(
self,
Expand Down Expand Up @@ -517,47 +506,56 @@ def _validate_checksums(
)
raise NonMatchingChecksumError(msg)

def _rename_and_get_final_dl_path(
self,
url: str,
path: epath.Path,
expected_url_info: checksums.UrlInfo | None,
computed_url_info: checksums.UrlInfo | None,
checksum_path: epath.Path | None,
url_path: epath.Path,
) -> epath.Path:
"""Eventually rename the downloaded file if checksums were recorded."""
# `path` can be:
# * Manually downloaded
# * (cached) checksum_path
# * (cached) url_path
# * `tmp_dir/file` (downloaded path)
if self._manual_dir and path.is_relative_to(self._manual_dir):
return path # Manually downloaded data
elif path == checksum_path: # Path already at final destination
assert computed_url_info == expected_url_info # Sanity check
return checksum_path # pytype: disable=bad-return-type
elif path == url_path:
if checksum_path:
# Checksums were registered: Rename -> checksums_path
resource_lib.replace_info_file(path, checksum_path)
return path.replace(checksum_path)
else:
# Checksums not registered: -> do nothing
return path
else: # Path was downloaded in tmp dir
dst_path = checksum_path or url_path
def _download(
    self, resource: resource_lib.Resource
) -> promise.Promise[epath.Path]:
  """Downloads resource.

  The file is fetched into a unique temporary directory next to its final
  destination, then checksums are registered/validated and the file is moved
  to its final path (checksum-named if a checksum was registered, otherwise
  the url-named path).

  Args:
    resource: The resource to download.

  Returns:
    Promise of the path to the downloaded url.
  """
  url_path = self._get_dl_path(resource)
  url: str = resource.url  # pytype: disable=annotation-type-mismatch

  # Download in a tmp directory next to url_path (to avoid name collisions)
  # `download_tmp_dir` is cleaned-up in `callback`
  # A fresh uuid suffix per call means concurrent downloads of the same URL
  # cannot collide on the temporary directory name.
  download_tmp_dir = (
      url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
  )
  download_tmp_dir.mkdir()
  logging.info(f'Downloading {url} into {download_tmp_dir}...')
  future = self._downloader.download(
      url, download_tmp_dir, verify=self._verify_ssl
  )

  def callback(dl_result: downloader.DownloadResult) -> epath.Path:
    """Post-process the download result."""
    dl_path = dl_result.path
    dl_url_info = dl_result.url_info

    # Returns the checksum-named destination when checksums were registered,
    # or a falsy value otherwise (in which case the url-named path is used).
    dst_path = self._register_or_validate_checksums(
        resource=resource, computed_url_info=dl_url_info, path=dl_path
    )
    if not dst_path:
      dst_path = url_path

    # Write the sidecar `.INFO` file before moving the payload, so the final
    # destination always has its metadata available.
    resource_lib.write_info_file(
        url=url,
        path=dst_path,
        dataset_name=self._dataset_name,
        original_fname=dl_path.name,
        url_info=dl_url_info,
    )
    dl_path.replace(dst_path)
    dl_path.parent.rmdir()  # Cleanup tmp dir (will fail if dir not empty)

    return dst_path

  return future.then(callback)

@utils.build_synchronize_decorator()
@utils.memoize()
def _extract(self, resource: ExtractPath) -> promise.Promise[epath.Path]:
Expand Down Expand Up @@ -587,7 +585,7 @@ def callback(path):
resource.path = path
return self._extract(resource)

return self._download(resource).then(callback)
return self._download_or_get_cache(resource).then(callback)

def download_checksums(self, checksums_url):
"""Downloads checksum file from the given URL and adds it to registry."""
Expand Down Expand Up @@ -636,7 +634,7 @@ def download(self, url_or_urls):
"""
# Add progress bar to follow the download state
with self._downloader.tqdm():
return _map_promise(self._download, url_or_urls)
return _map_promise(self._download_or_get_cache, url_or_urls)

def iter_archive(
self,
Expand Down
49 changes: 1 addition & 48 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
return _Downloader(*args, **kwargs)


def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
def read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
"""Loads the `UrlInfo` from the `.INFO` file."""
file_info = resource_lib.read_info_file(url_path)
if 'url_info' not in file_info:
Expand All @@ -75,53 +75,6 @@ def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
return checksums_lib.UrlInfo(**url_info)


def get_cached_path(
    manually_downloaded_path: epath.Path | None,
    checksum_path: epath.Path | None,
    url_path: epath.Path,
    expected_url_info: checksums_lib.UrlInfo | None,
) -> DownloadResult | None:
  """Looks up a previously downloaded copy of the resource.

  A cached copy may live in one of three locations, checked in priority
  order: the manual directory, the checksum-named path, or the url-named
  path. Returning `None` signals the caller must (re)download the file,
  e.g. when nothing is cached or the cached file fails checksum validation.

  Args:
    manually_downloaded_path: Manually downloaded in `dl_manager.manual_dir`
    checksum_path: Cached in the final destination (if checksum known)
    url_path: Cached in the tmp destination (if checksum unknown).
    expected_url_info: Registered checksum (if known)
  """
  # Highest priority: a file the user placed in the manual directory.
  if manually_downloaded_path and manually_downloaded_path.exists():
    manual_info = checksums_lib.compute_url_info(manually_downloaded_path)
    return DownloadResult(path=manually_downloaded_path, url_info=manual_info)

  # Cached at the checksum-named destination: the file name is derived from
  # the checksum, so a hit implies the content matches `expected_url_info`.
  if checksum_path and resource_lib.Resource.exists_locally(checksum_path):
    return DownloadResult(checksum_path, url_info=expected_url_info)

  # Cached at the url-named destination (checksum was unknown at download
  # time); the url-info is restored from the `.INFO` sidecar file.
  if resource_lib.Resource.exists_locally(url_path):
    recorded_info = _read_url_info(url_path)
    if expected_url_info and recorded_info != expected_url_info:
      # Checksums are now registered but do not match the cached file
      # (e.g. file corrupted or checksums updated): trigger a re-download.
      return None
    return DownloadResult(path=url_path, url_info=recorded_info)

  # Nothing cached (or bad checksums): the file must be (re)downloaded.
  return None


def _filename_from_content_disposition(
content_disposition: str,
) -> str | None:
Expand Down

0 comments on commit c37ca97

Please sign in to comment.