diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py index 0537af9db81..4a8012ee473 100644 --- a/tensorflow_datasets/core/dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builder.py @@ -294,6 +294,18 @@ def __init__( self.info.read_from_directory(self._data_dir) else: # Use the code version (do not restore data) self.info.initialize_from_bucket() + if self.BLOCKED_VERSIONS is not None: + config_name = self._builder_config.name if self._builder_config else None + if is_blocked := self.BLOCKED_VERSIONS.is_blocked( + version=self._version, config=config_name + ): + default_msg = ( + f"Dataset {self.name} is blocked at version {self._version} and" + f" config {config_name}." + ) + self.info.set_is_blocked( + is_blocked.blocked_msg if is_blocked.blocked_msg else default_msg + ) @utils.classproperty @classmethod diff --git a/tensorflow_datasets/core/dataset_info.py b/tensorflow_datasets/core/dataset_info.py index 121b79862ed..2325175292e 100644 --- a/tensorflow_datasets/core/dataset_info.py +++ b/tensorflow_datasets/core/dataset_info.py @@ -197,6 +197,7 @@ def __init__( alternative_file_formats: ( Sequence[str | file_adapters.FileFormat] | None ) = None, + is_blocked: str | None = None, # LINT.ThenChange(:setstate) ): # pyformat: disable @@ -243,6 +244,8 @@ def __init__( split_dict: information about the splits in this dataset. alternative_file_formats: alternative file formats that are availablefor this dataset. + is_blocked: A message explaining why the dataset, in its version and + config, is blocked. If empty or None, the dataset is not blocked. """ # pyformat: enable self._builder_or_identity = builder @@ -259,6 +262,8 @@ def __init__( f = file_adapters.FileFormat.from_value(f) self._alternative_file_formats.append(f) + self._is_blocked = is_blocked + self._info_proto = dataset_info_pb2.DatasetInfo( name=self._identity.name, description=utils.dedent(description), @@ -276,6 +281,7 @@ def __init__( alternative_file_formats=[ f.value for f in self._alternative_file_formats ], + is_blocked=self._is_blocked, ) if homepage: @@ -440,6 +446,13 @@ def alternative_file_formats(self) -> Sequence[file_adapters.FileFormat]: def metadata(self) -> Metadata | None: return self._metadata + @property + def is_blocked(self) -> str | None: + return self._is_blocked + + def set_is_blocked(self, is_blocked: str) -> None: + self._is_blocked = is_blocked + @property def supervised_keys(self) -> Optional[SupervisedKeysType]: if not self.as_proto.HasField("supervised_keys"): @@ -941,6 +954,7 @@ def __getstate__(self): "license": self.redistribution_info.license, "split_dict": self.splits, "alternative_file_formats": self.alternative_file_formats, + "is_blocked": self.is_blocked, } def __setstate__(self, state): # LINT.IfChange(setstate) @@ -956,6 +970,7 @@ def __setstate__(self, state): license=state["license"], split_dict=state["split_dict"], alternative_file_formats=state["alternative_file_formats"], + is_blocked=state["is_blocked"], ) # LINT.ThenChange(:dataset_info_args) diff --git a/tensorflow_datasets/core/proto/dataset_info.proto b/tensorflow_datasets/core/proto/dataset_info.proto index 6ce3f2ff4d9..bd77191562c 100644 --- a/tensorflow_datasets/core/proto/dataset_info.proto +++ b/tensorflow_datasets/core/proto/dataset_info.proto @@ -231,5 +231,9 @@ message DatasetInfo { // The data that was used to generate this dataset. repeated DataSourceAccess data_source_accesses = 20; - // Next available: 22 + // A message explaining why the dataset is blocked. If empty, it means that + // the dataset is not blocked. + string is_blocked = 23; + + // Next available: 24 } diff --git a/tensorflow_datasets/core/proto/dataset_info_generated_pb2.py b/tensorflow_datasets/core/proto/dataset_info_generated_pb2.py index 4fc5d2ec9f7..96058616c64 100644 --- a/tensorflow_datasets/core/proto/dataset_info_generated_pb2.py +++ b/tensorflow_datasets/core/proto/dataset_info_generated_pb2.py @@ -67,7 +67,7 @@ b' \x01(\t\x12\x10\n\x08\x64\x61ta_dir\x18\x04' b' \x01(\t\x12\x14\n\x0c\x64s_namespace\x18\x05' b' \x01(\t\x12\r\n\x05split\x18\x06' - b' \x01(\t"\xd6\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01' + b' \x01(\t"\xea\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01' b' \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02' b' \x01(\t\x12\x0f\n\x07version\x18\t \x01(\t\x12I\n\rrelease_notes\x18\x12' b' \x03(\x0b\x32\x32.tensorflow_datasets.DatasetInfo.ReleaseNotesEntry\x12\x13\n\x0b\x63onfig_name\x18\r' @@ -88,7 +88,8 @@ b' \x01(\x08\x12\x13\n\x0b\x66ile_format\x18\x11 \x01(\t\x12' b' \n\x18\x61lternative_file_formats\x18\x16' b' \x03(\t\x12\x43\n\x14\x64\x61ta_source_accesses\x18\x14' - b' \x03(\x0b\x32%.tensorflow_datasets.DataSourceAccess\x1a\x33\n\x11ReleaseNotesEntry\x12\x0b\n\x03key\x18\x01' + b' \x03(\x0b\x32%.tensorflow_datasets.DataSourceAccess\x12\x12\n\nis_blocked\x18\x17' + b' \x01(\t\x1a\x33\n\x11ReleaseNotesEntry\x12\x0b\n\x03key\x18\x01' b' \x01(\t\x12\r\n\x05value\x18\x02' b' \x01(\t:\x02\x38\x01\x1a\x38\n\x16\x44ownloadChecksumsEntry\x12\x0b\n\x03key\x18\x01' b' \x01(\t\x12\r\n\x05value\x18\x02' @@ -146,9 +147,9 @@ _TFDSDATASETREFERENCE._serialized_start = 1280 _TFDSDATASETREFERENCE._serialized_end = 1404 _DATASETINFO._serialized_start = 1407 - _DATASETINFO._serialized_end = 2389 - _DATASETINFO_RELEASENOTESENTRY._serialized_start = 2280 - _DATASETINFO_RELEASENOTESENTRY._serialized_end = 2331 - _DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2333 - _DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2389 + _DATASETINFO._serialized_end = 2409 + _DATASETINFO_RELEASENOTESENTRY._serialized_start = 2300 + _DATASETINFO_RELEASENOTESENTRY._serialized_end = 2351 + _DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2353 + _DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2409 # @@protoc_insertion_point(module_scope) diff --git a/tensorflow_datasets/core/read_only_builder.py b/tensorflow_datasets/core/read_only_builder.py index 54b9a00faef..fba213f7bf3 100644 --- a/tensorflow_datasets/core/read_only_builder.py +++ b/tensorflow_datasets/core/read_only_builder.py @@ -78,6 +78,7 @@ def __init__( self.name = info_proto.name self.VERSION = version_lib.Version(info_proto.version) # pylint: disable=invalid-name self.RELEASE_NOTES = info_proto.release_notes or {} # pylint: disable=invalid-name + self.BLOCKED_VERSIONS = self._restore_blocked_versions(info_proto) # pylint: disable=invalid-name if info_proto.module_name: # Overwrite the module so documenting `ReadOnlyBuilder` point to the @@ -92,6 +93,7 @@ def __init__( config=builder_config, version=info_proto.version, ) + self.assert_is_not_blocked() # For pickling, should come after super.__init__ which is setting that same # _original_state attribute. @@ -103,6 +105,25 @@ def __init__( 'was generated with an old TFDS version (<=3.2.1).' ) + def _restore_blocked_versions( + self, info_proto: dataset_info_pb2.DatasetInfo + ) -> version_lib.BlockedVersions | None: + """Restores the blocked version information from the dataset info proto. + + Args: + info_proto: DatasetInfo describing the name, config, etc of the requested + dataset. + + Returns: + None if the dataset is not blocked, or a populated BlockedVersions object. + """ + if info_proto.is_blocked: + configs = { + info_proto.version: {info_proto.config_name: info_proto.is_blocked} + } + return version_lib.BlockedVersions(configs=configs) + return None + def _create_builder_config( self, builder_config: str | dataset_builder.BuilderConfig | None, diff --git a/tensorflow_datasets/core/read_only_builder_test.py b/tensorflow_datasets/core/read_only_builder_test.py index b023b8cf3ae..ad004852955 100644 --- a/tensorflow_datasets/core/read_only_builder_test.py +++ b/tensorflow_datasets/core/read_only_builder_test.py @@ -246,6 +246,29 @@ def test_builder_from_metadata( assert str(builder.info.features) == str(dummy_features) +def test_restore_blocked_versions( + code_builder: dataset_builder.DatasetBuilder, + dummy_features: features_dict.FeaturesDict, +): + info_proto = dataset_info_pb2.DatasetInfo( + name='abcd', + description='efgh', + config_name='en', + config_description='something', + version='0.1.0', + release_notes={'0.1.0': 'release description'}, + citation='some citation', + features=dummy_features.to_proto(), + is_blocked='some reason for blocking', + ) + with pytest.raises( + utils.DatasetVariantBlockedError, match='some reason for blocking' + ): + read_only_builder.builder_from_metadata( + code_builder.data_dir, info_proto=info_proto + ) + + def test_builder_from_directory_dir_not_exists(tmp_path: pathlib.Path): with pytest.raises(FileNotFoundError, match='Could not load dataset info'): read_only_builder.builder_from_directory(tmp_path)