From c37ca97de8b0522e0c025cc164ced743d2fbd214 Mon Sep 17 00:00:00 2001 From: Tim Semenov Date: Fri, 27 Sep 2024 11:21:47 -0700 Subject: [PATCH] Refactor download_manager.py PiperOrigin-RevId: 679662316 --- .../core/download/download_manager.py | 230 +++++++++--------- .../core/download/downloader.py | 49 +--- 2 files changed, 115 insertions(+), 164 deletions(-) diff --git a/tensorflow_datasets/core/download/download_manager.py b/tensorflow_datasets/core/download/download_manager.py index 0855b8e6a21..740fd401ece 100644 --- a/tensorflow_datasets/core/download/download_manager.py +++ b/tensorflow_datasets/core/download/download_manager.py @@ -355,15 +355,11 @@ def _get_manually_downloaded_path( # processed once, even if passed twice to download_manager. @utils.build_synchronize_decorator() @utils.memoize() - def _download(self, resource: Url) -> promise.Promise[epath.Path]: + def _download_or_get_cache( + self, resource: Url + ) -> promise.Promise[epath.Path]: """Downloads resource or gets downloaded cache. - This function: - - 1. Reuse cache (`_get_cached_path`) or download the file - 2. Register or validate checksums (`_register_or_validate_checksums`) - 3. Rename download to final path (`_rename_and_get_final_dl_path`) - Args: resource: The URL to download. @@ -378,76 +374,79 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]: expected_url_info = self._url_infos.get(url) - # 3 possible destinations for the path: - # * In `manual_dir` (manually downloaded data) - # * In `downloads/url_path` (checksum unknown) - # * In `downloads/checksum_path` (checksum registered) - manually_downloaded_path = self._get_manually_downloaded_path( - expected_url_info=expected_url_info - ) - url_path = self._get_dl_path(resource) - checksum_path = ( - self._get_dl_path(resource, expected_url_info.checksum) - if expected_url_info - else None - ) - - # Get the cached path and url_info (if they exists) - dl_result = downloader.get_cached_path( - manually_downloaded_path=manually_downloaded_path, - checksum_path=checksum_path, - url_path=url_path, - expected_url_info=expected_url_info, - ) - if dl_result and not self._force_download: # Download was cached - logging.info( - f'Skipping download of {url}: File cached in {dl_result.path}' + # User has manually downloaded the file. + if manually_downloaded_path := self._get_manually_downloaded_path( + expected_url_info + ): + computed_url_info = checksums.compute_url_info(manually_downloaded_path) + self._register_or_validate_checksums( + resource=resource, + path=manually_downloaded_path, + computed_url_info=computed_url_info, ) - # Still update the progression bar to indicate the file was downloaded - self._downloader.increase_tqdm(dl_result.url_info) - future = promise.Promise.resolve(dl_result) - else: - # Download in a tmp directory next to url_path (to avoid name collisions) - # `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path` - download_tmp_dir = ( - url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}' + self._log_skip_download( + url=url, url_info=computed_url_info, path=manually_downloaded_path ) - download_tmp_dir.mkdir() - logging.info(f'Downloading {url} into {download_tmp_dir}...') - future = self._downloader.download( - url, download_tmp_dir, verify=self._verify_ssl + return promise.Promise.resolve(manually_downloaded_path) + + # Force download + elif self._force_download: + return self._download(resource) + + # Download has been cached (checksum known) + elif expected_url_info and resource_lib.Resource.exists_locally( + checksum_path := self._get_dl_path(resource, expected_url_info.checksum) + ): + self._register_or_validate_checksums( + resource=resource, + path=checksum_path, + computed_url_info=expected_url_info, ) + self._log_skip_download( + url=url, url_info=expected_url_info, path=checksum_path + ) + return promise.Promise.resolve(checksum_path) + + # Download has been cached (checksum unknown) + elif resource_lib.Resource.exists_locally( + url_path := self._get_dl_path(resource) + ): + computed_url_info = downloader.read_url_info(url_path) + if expected_url_info and expected_url_info != computed_url_info: + # If checksums are registered but do not match, trigger a new + # download (e.g. previous file corrupted, checksums updated) + return self._download(resource) + if checksum_path := self._register_or_validate_checksums( + resource=resource, path=url_path, computed_url_info=computed_url_info + ): + # Checksums were registered: Rename -> checksum_path + resource_lib.replace_info_file(url_path, checksum_path) + path = url_path.replace(checksum_path) + else: + # Checksums not registered: -> do nothing + path = url_path + self._log_skip_download(url=url, url_info=computed_url_info, path=path) + return promise.Promise.resolve(path) - # Post-process the result - return future.then( - lambda dl_result: self._register_or_validate_checksums( # pylint: disable=g-long-lambda - resource=resource, - path=dl_result.path, - computed_url_info=dl_result.url_info, - expected_url_info=expected_url_info, - checksum_path=checksum_path, - url_path=url_path, - ) - ) + # Cache not found + else: + return self._download(resource) + + def _log_skip_download( + self, url: str, url_info: checksums.UrlInfo, path: epath.Path + ) -> None: + logging.info(f'Skipping download of {url}: File cached in {path}') + # Still update the progression bar to indicate the file was downloaded + self._downloader.increase_tqdm(url_info) def _register_or_validate_checksums( self, resource: resource_lib.Resource, path: epath.Path, - expected_url_info: checksums.UrlInfo | None, computed_url_info: checksums.UrlInfo, - checksum_path: epath.Path | None, - url_path: epath.Path, - ) -> epath.Path: - """Validates/records checksums and renames final downloaded path.""" - # `path` can be: - # * Manually downloaded - # * (cached) checksum_path - # * (cached) url_path - # * `tmp_dir/file` (downloaded path) - + ) -> epath.Path | None: + """Validates/records checksums and returns checksum path if registered.""" url: str = resource.url # pytype: disable=annotation-type-mismatch - # Used both in `.downloaded_size` and `_record_url_infos()` self._recorded_url_infos[url] = computed_url_info if self._register_checksums: @@ -457,12 +456,9 @@ def _register_or_validate_checksums( # * `register_checksums_path` was validated in `__init__` so this # shouldn't fail. self._record_url_infos() - - # Checksum path should now match the new registered checksum (even if - # checksums were previously registered) - expected_url_info = computed_url_info - checksum_path = self._get_dl_path(resource, computed_url_info.checksum) + return self._get_dl_path(resource, computed_url_info.checksum) else: + expected_url_info = self._url_infos.get(url) # Eventually validate checksums # Note: # * If path is cached at `url_path` but cached @@ -478,15 +474,8 @@ def _register_or_validate_checksums( computed_url_info=computed_url_info, path=path, ) - - return self._rename_and_get_final_dl_path( - url=url, - path=path, - expected_url_info=expected_url_info, - computed_url_info=computed_url_info, - checksum_path=checksum_path, - url_path=url_path, - ) + if expected_url_info: + return self._get_dl_path(resource, expected_url_info.checksum) def _validate_checksums( self, @@ -517,47 +506,56 @@ def _validate_checksums( ) raise NonMatchingChecksumError(msg) - def _rename_and_get_final_dl_path( - self, - url: str, - path: epath.Path, - expected_url_info: checksums.UrlInfo | None, - computed_url_info: checksums.UrlInfo | None, - checksum_path: epath.Path | None, - url_path: epath.Path, - ) -> epath.Path: - """Eventually rename the downloaded file if checksums were recorded.""" - # `path` can be: - # * Manually downloaded - # * (cached) checksum_path - # * (cached) url_path - # * `tmp_dir/file` (downloaded path) - if self._manual_dir and path.is_relative_to(self._manual_dir): - return path # Manually downloaded data - elif path == checksum_path: # Path already at final destination - assert computed_url_info == expected_url_info # Sanity check - return checksum_path # pytype: disable=bad-return-type - elif path == url_path: - if checksum_path: - # Checksums were registered: Rename -> checksums_path - resource_lib.replace_info_file(path, checksum_path) - return path.replace(checksum_path) - else: - # Checksums not registered: -> do nothing - return path - else: # Path was downloaded in tmp dir - dst_path = checksum_path or url_path + def _download( + self, resource: resource_lib.Resource + ) -> promise.Promise[epath.Path]: + """Downloads resource. + + Args: + resource: The resource to download. + + Returns: + Promise of the path to the downloaded url. + """ + url_path = self._get_dl_path(resource) + url: str = resource.url # pytype: disable=annotation-type-mismatch + + # Download in a tmp directory next to url_path (to avoid name collisions) + # `download_tmp_dir` is cleaned-up in `callback` + download_tmp_dir = ( + url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}' + ) + download_tmp_dir.mkdir() + logging.info(f'Downloading {url} into {download_tmp_dir}...') + future = self._downloader.download( + url, download_tmp_dir, verify=self._verify_ssl + ) + + def callback(dl_result: downloader.DownloadResult) -> epath.Path: + """Post-process the download result.""" + dl_path = dl_result.path + dl_url_info = dl_result.url_info + + dst_path = self._register_or_validate_checksums( + resource=resource, computed_url_info=dl_url_info, path=dl_path + ) + if not dst_path: + dst_path = url_path + resource_lib.write_info_file( url=url, path=dst_path, dataset_name=self._dataset_name, - original_fname=path.name, - url_info=computed_url_info, + original_fname=dl_path.name, + url_info=dl_url_info, ) - path.replace(dst_path) - path.parent.rmdir() # Cleanup tmp dir (will fail if dir not empty) + dl_path.replace(dst_path) + dl_path.parent.rmdir() # Cleanup tmp dir (will fail if dir not empty) + return dst_path + return future.then(callback) + @utils.build_synchronize_decorator() @utils.memoize() def _extract(self, resource: ExtractPath) -> promise.Promise[epath.Path]: @@ -587,7 +585,7 @@ def callback(path): resource.path = path return self._extract(resource) - return self._download(resource).then(callback) + return self._download_or_get_cache(resource).then(callback) def download_checksums(self, checksums_url): """Downloads checksum file from the given URL and adds it to registry.""" @@ -636,7 +634,7 @@ def download(self, url_or_urls): """ # Add progress bar to follow the download state with self._downloader.tqdm(): - return _map_promise(self._download, url_or_urls) + return _map_promise(self._download_or_get_cache, url_or_urls) def iter_archive( self, diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py index ce089127cd7..a6ac604481f 100644 --- a/tensorflow_datasets/core/download/downloader.py +++ b/tensorflow_datasets/core/download/downloader.py @@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader': return _Downloader(*args, **kwargs) -def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo: +def read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo: """Loads the `UrlInfo` from the `.INFO` file.""" file_info = resource_lib.read_info_file(url_path) if 'url_info' not in file_info: @@ -75,53 +75,6 @@ def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo: return checksums_lib.UrlInfo(**url_info) -def get_cached_path( - manually_downloaded_path: epath.Path | None, - checksum_path: epath.Path | None, - url_path: epath.Path, - expected_url_info: checksums_lib.UrlInfo | None, -) -> DownloadResult | None: - """Returns the downloaded path and computed url-info. - - If the path is not cached, or that `url_path` does not match checksums, - the file will be downloaded again. - - Path can be cached at three different locations: - - Args: - manually_downloaded_path: Manually downloaded in `dl_manager.manual_dir` - checksum_path: Cached in the final destination (if checksum known) - url_path: Cached in the tmp destination (if checksum unknown). - expected_url_info: Registered checksum (if known) - """ - # User has manually downloaded the file. - if manually_downloaded_path and manually_downloaded_path.exists(): - computed_url_info = checksums_lib.compute_url_info(manually_downloaded_path) - return DownloadResult( - path=manually_downloaded_path, url_info=computed_url_info - ) - - # Download has been cached (checksum known) - elif checksum_path and resource_lib.Resource.exists_locally(checksum_path): - # `path = f(checksum)` was found, so url_info match - return DownloadResult(checksum_path, url_info=expected_url_info) - - # Download has been cached (checksum unknown) - elif resource_lib.Resource.exists_locally(url_path): - # Info restored from `.INFO` file - computed_url_info = _read_url_info(url_path) - # If checksums are now registered but do not match, trigger a new - # download (e.g. previous file corrupted, checksums updated) - if expected_url_info and computed_url_info != expected_url_info: - return None - else: - return DownloadResult(path=url_path, url_info=computed_url_info) - - # Else file not found (or has bad checksums). (re)download. - else: - return None - - def _filename_from_content_disposition( content_disposition: str, ) -> str | None: