From fc31737a8cc5b36728edafd814038947c158cbc7 Mon Sep 17 00:00:00 2001
From: Tim Semenov
Date: Fri, 20 Sep 2024 07:29:36 -0700
Subject: [PATCH] Use epath.Path in downloader.py

PiperOrigin-RevId: 676837944
---
 .../core/download/download_manager.py         | 121 +++++++++---------
 .../core/download/downloader.py               |   6 +-
 .../core/download/downloader_test.py          |  35 +++--
 3 files changed, 75 insertions(+), 87 deletions(-)

diff --git a/tensorflow_datasets/core/download/download_manager.py b/tensorflow_datasets/core/download/download_manager.py
index c317cd8c1ad..f210f26c3d3 100644
--- a/tensorflow_datasets/core/download/download_manager.py
+++ b/tensorflow_datasets/core/download/download_manager.py
@@ -297,7 +297,7 @@ def __getstate__(self):
     return state
 
   @property
-  def _downloader(self):
+  def _downloader(self) -> downloader._Downloader:
     if not self.__downloader:
       self.__downloader = get_downloader(
           max_simultaneous_downloads=self._max_simultaneous_downloads
@@ -305,13 +305,13 @@ def __getstate__(self):
     return self.__downloader
 
   @property
-  def _extractor(self):
+  def _extractor(self) -> extractor._Extractor:
     if not self.__extractor:
       self.__extractor = extractor.get_extractor()
     return self.__extractor
 
   @property
-  def downloaded_size(self):
+  def downloaded_size(self) -> int:
     """Returns the total size of downloaded files."""
     return sum(url_info.size for url_info in self._recorded_url_infos.values())
 
@@ -331,6 +331,22 @@ def _record_url_infos(self):
         self._recorded_url_infos,
     )
 
+  def _get_manually_downloaded_path(
+      self, expected_url_info: checksums.UrlInfo | None
+  ) -> epath.Path | None:
+    """Checks if file is already downloaded in manual_dir."""
+    if not self._manual_dir:  # Manual dir not passed
+      return None
+
+    if not expected_url_info or not expected_url_info.filename:
+      return None  # Filename unknown.
+
+    manual_path = self._manual_dir / expected_url_info.filename
+    if not manual_path.exists():  # File not manually downloaded
+      return None
+
+    return manual_path
+
   # Synchronize and memoize decorators ensure same resource will only be
   # processed once, even if passed twice to download_manager.
   @utils.build_synchronize_decorator()
@@ -363,9 +379,8 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
     # * In `manual_dir` (manually downloaded data)
     # * In `downloads/url_path` (checksum unknown)
     # * In `downloads/checksum_path` (checksum registered)
-    manually_downloaded_path = _get_manually_downloaded_path(
-        manual_dir=self._manual_dir,
-        expected_url_info=expected_url_info,
+    manually_downloaded_path = self._get_manually_downloaded_path(
+        expected_url_info=expected_url_info
     )
     url_path = self._get_dl_path(url)
     checksum_path = (
@@ -459,12 +474,11 @@ def _register_or_validate_checksums(
       # the download isn't cached (re-running build will retrigger a new
      # download). This is expected as it might mean the downloaded file
       # was corrupted. Note: The tmp file isn't deleted to allow inspection.
-      _validate_checksums(
+      self._validate_checksums(
           url=url,
           path=path,
           expected_url_info=expected_url_info,
           computed_url_info=computed_url_info,
-          force_checksums_validation=self._force_checksums_validation,
       )
 
     return self._rename_and_get_final_dl_path(
@@ -476,6 +490,42 @@ def _register_or_validate_checksums(
         url_path=url_path,
     )
 
+  def _validate_checksums(
+      self,
+      url: str,
+      path: epath.Path,
+      computed_url_info: checksums.UrlInfo | None,
+      expected_url_info: checksums.UrlInfo | None,
+  ) -> None:
+    """Validates that computed_url_info matches expected_url_info."""
+    # If force_checksums_validation is set, both the expected and the computed
+    # url_info should exist.
+    if self._force_checksums_validation:
+      # Checksum of the downloaded file unknown (for manually downloaded file)
+      if not computed_url_info:
+        computed_url_info = checksums.compute_url_info(path)
+      # Checksums have not been registered
+      if not expected_url_info:
+        raise ValueError(
+            f'Missing checksums url: {url}, yet '
+            '`force_checksums_validation=True`. '
+            'Did you forget to register checksums?'
+        )
+
+    if (
+        expected_url_info
+        and computed_url_info
+        and expected_url_info != computed_url_info
+    ):
+      msg = (
+          f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
+          f'* Expected: {expected_url_info}\n'
+          f'* Got: {computed_url_info}\n'
+          'To debug, see: '
+          'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
+      )
+      raise NonMatchingChecksumError(msg)
+
   def _rename_and_get_final_dl_path(
       self,
       url: str,
@@ -707,61 +757,6 @@ def manual_dir(self) -> epath.Path:
     return self._manual_dir
 
 
-def _get_manually_downloaded_path(
-    manual_dir: epath.Path | None,
-    expected_url_info: checksums.UrlInfo | None,
-) -> epath.Path | None:
-  """Checks if file is already downloaded in manual_dir."""
-  if not manual_dir:  # Manual dir not passed
-    return None
-
-  if not expected_url_info or not expected_url_info.filename:
-    return None  # Filename unknown.
-
-  manual_path = manual_dir / expected_url_info.filename
-  if not manual_path.exists():  # File not manually downloaded
-    return None
-
-  return manual_path
-
-
-def _validate_checksums(
-    url: str,
-    path: epath.Path,
-    computed_url_info: checksums.UrlInfo | None,
-    expected_url_info: checksums.UrlInfo | None,
-    force_checksums_validation: bool,
-) -> None:
-  """Validate computed_url_info match expected_url_info."""
-  # If force-checksums validations, both expected and computed url_info
-  # should exists
-  if force_checksums_validation:
-    # Checksum of the downloaded file unknown (for manually downloaded file)
-    if not computed_url_info:
-      computed_url_info = checksums.compute_url_info(path)
-    # Checksums have not been registered
-    if not expected_url_info:
-      raise ValueError(
-          f'Missing checksums url: {url}, yet '
-          '`force_checksums_validation=True`. '
-          'Did you forget to register checksums?'
-      )
-
-  if (
-      expected_url_info
-      and computed_url_info
-      and expected_url_info != computed_url_info
-  ):
-    msg = (
-        f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
-        f'* Expected: {expected_url_info}\n'
-        f'* Got: {computed_url_info}\n'
-        'To debug, see: '
-        'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
-    )
-    raise NonMatchingChecksumError(msg)
-
-
 def _map_promise(map_fn, all_inputs):
   """Map the function into each element and resolve the promise."""
   all_promises = tree.map_structure(map_fn, all_inputs)  # Apply the function
diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py
index 3faa0deda70..40219802ffd 100644
--- a/tensorflow_datasets/core/download/downloader.py
+++ b/tensorflow_datasets/core/download/downloader.py
@@ -225,7 +225,7 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:
       self._pbar_dl_size.update(dl_result.url_info.size)
 
   def download(
-      self, url: str, destination_path: str, verify: bool = True
+      self, url: str, destination_path: epath.Path, verify: bool = True
   ) -> promise.Promise[concurrent.futures.Future[DownloadResult]]:
     """Download url to given path.
 
@@ -239,7 +239,6 @@ def download(
     Returns:
       Promise obj -> Download result.
     """
-    destination_path = os.fspath(destination_path)
     self._pbar_url.update_total(1)
     future = self._executor.submit(
         self._sync_download, url, destination_path, verify
@@ -264,7 +263,7 @@ def _sync_file_copy(
     return DownloadResult(path=out_path, url_info=url_info)
 
   def _sync_download(
-      self, url: str, destination_path: str, verify: bool = True
+      self, url: str, destination_path: epath.Path, verify: bool = True
   ) -> DownloadResult:
     """Synchronous version of `download` method.
 
@@ -284,7 +283,6 @@ def _sync_download(
     Raises:
       DownloadError: when download fails.
     """
-    destination_path = epath.Path(destination_path)
     try:
       # If url is on a filesystem that gfile understands, use copy. Otherwise,
       # use requests (http) or urllib (ftp).
diff --git a/tensorflow_datasets/core/download/downloader_test.py b/tensorflow_datasets/core/download/downloader_test.py
index 517d256393f..550039e1e6d 100644
--- a/tensorflow_datasets/core/download/downloader_test.py
+++ b/tensorflow_datasets/core/download/downloader_test.py
@@ -13,17 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for downloader."""
-
 import hashlib
 import io
-import os
-import tempfile
 from typing import Optional
 from unittest import mock
 
+from etils import epath
 import pytest
-import tensorflow as tf
 from tensorflow_datasets import testing
 from tensorflow_datasets.core.download import downloader
 from tensorflow_datasets.core.download import resource as resource_lib
@@ -59,11 +55,13 @@ def setUp(self):
     super(DownloaderTest, self).setUp()
     self.addCleanup(mock.patch.stopall)
     self.downloader = downloader.get_downloader(10, hashlib.sha256)
-    self.tmp_dir = tempfile.mkdtemp(dir=tf.compat.v1.test.get_temp_dir())
+    self.tmp_dir = epath.Path(self.tmp_dir)
     self.url = 'http://example.com/foo.tar.gz'
     self.resource = resource_lib.Resource(url=self.url)
-    self.path = os.path.join(self.tmp_dir, 'foo.tar.gz')
-    self.incomplete_path = '%s.incomplete' % self.path
+    self.path = self.tmp_dir / 'foo.tar.gz'
+    self.incomplete_path = self.path.with_suffix(
+        self.path.suffix + '.incomplete'
+    )
     self.response = b'This \nis an \nawesome\n response!'
     self.resp_checksum = hashlib.sha256(self.response).hexdigest()
     self.cookies = {}
@@ -84,22 +82,20 @@ def test_ok(self):
     promise = self.downloader.download(self.url, self.tmp_dir)
     future = promise.get()
     url_info = future.url_info
-    self.assertEqual(self.path, os.fspath(future.path))
+    self.assertEqual(self.path, future.path)
     self.assertEqual(url_info.checksum, self.resp_checksum)
-    with tf.io.gfile.GFile(self.path, 'rb') as result:
-      self.assertEqual(result.read(), self.response)
-    self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
+    self.assertEqual(self.path.read_bytes(), self.response)
+    self.assertFalse(self.incomplete_path.exists())
 
   def test_drive_no_cookies(self):
     url = 'https://drive.google.com/uc?export=download&id=a1b2bc3'
     promise = self.downloader.download(url, self.tmp_dir)
     future = promise.get()
     url_info = future.url_info
-    self.assertEqual(self.path, os.fspath(future.path))
+    self.assertEqual(self.path, future.path)
     self.assertEqual(url_info.checksum, self.resp_checksum)
-    with tf.io.gfile.GFile(self.path, 'rb') as result:
-      self.assertEqual(result.read(), self.response)
-    self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
+    self.assertEqual(self.path.read_bytes(), self.response)
+    self.assertFalse(self.incomplete_path.exists())
 
   def test_drive(self):
     self.cookies = {'foo': 'bar', 'download_warning_a': 'token', 'a': 'b'}
@@ -129,11 +125,10 @@ def test_ftp(self):
     promise = self.downloader.download(url, self.tmp_dir)
     future = promise.get()
     url_info = future.url_info
-    self.assertEqual(self.path, os.fspath(future.path))
+    self.assertEqual(self.path, future.path)
     self.assertEqual(url_info.checksum, self.resp_checksum)
-    with tf.io.gfile.GFile(self.path, 'rb') as result:
-      self.assertEqual(result.read(), self.response)
-    self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
+    self.assertEqual(self.path.read_bytes(), self.response)
+    self.assertFalse(self.incomplete_path.exists())
 
   def test_ftp_error(self):
     error = downloader.urllib.error.URLError('Problem serving file.')
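
The test changes above assume that `epath.Path` behaves like `pathlib.Path` (`/` joining, `with_suffix`, `read_bytes`, `exists`). A minimal standalone sketch of those idioms, for reference only; the directory and payload below are placeholders, not values from this patch:

    # Sketch of the epath.Path idioms the updated tests rely on.
    # '/tmp/tfds_example' and b'payload' are illustrative placeholders.
    from etils import epath

    tmp_dir = epath.Path('/tmp/tfds_example')
    tmp_dir.mkdir(parents=True, exist_ok=True)

    path = tmp_dir / 'foo.tar.gz'                               # replaces os.path.join
    incomplete = path.with_suffix(path.suffix + '.incomplete')  # foo.tar.gz.incomplete
    path.write_bytes(b'payload')
    assert path.read_bytes() == b'payload'                      # replaces tf.io.gfile.GFile(...).read()
    assert not incomplete.exists()                              # replaces tf.io.gfile.exists(...)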