Skip to content

Commit

Permalink
Add information about blocked versions and configs to dataset_info an…
Browse files Browse the repository at this point in the history
…d restore this information in our ReadOnlyBuilder.

PiperOrigin-RevId: 678169155
  • Loading branch information
The TensorFlow Datasets Authors committed Sep 24, 2024
1 parent 8223a15 commit 551e9d2
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 8 deletions.
12 changes: 12 additions & 0 deletions tensorflow_datasets/core/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,18 @@ def __init__(
self.info.read_from_directory(self._data_dir)
else: # Use the code version (do not restore data)
self.info.initialize_from_bucket()
if self.BLOCKED_VERSIONS is not None:
config_name = self._builder_config.name if self._builder_config else None
if is_blocked := self.BLOCKED_VERSIONS.is_blocked(
version=self._version, config=config_name
):
default_msg = (
f"Dataset {self.name} is blocked at version {self._version} and"
f" config {config_name}."
)
self.info.set_is_blocked(
is_blocked.blocked_msg if is_blocked.blocked_msg else default_msg
)

@utils.classproperty
@classmethod
Expand Down
15 changes: 15 additions & 0 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
alternative_file_formats: (
Sequence[str | file_adapters.FileFormat] | None
) = None,
is_blocked: str | None = None,
# LINT.ThenChange(:setstate)
):
# pyformat: disable
Expand Down Expand Up @@ -243,6 +244,8 @@ def __init__(
split_dict: information about the splits in this dataset.
alternative_file_formats: alternative file formats that are available for
this dataset.
is_blocked: A message explaining why the dataset, in its version and
config, is blocked. If empty or None, the dataset is not blocked.
"""
# pyformat: enable
self._builder_or_identity = builder
Expand All @@ -259,6 +262,8 @@ def __init__(
f = file_adapters.FileFormat.from_value(f)
self._alternative_file_formats.append(f)

self._is_blocked = is_blocked

self._info_proto = dataset_info_pb2.DatasetInfo(
name=self._identity.name,
description=utils.dedent(description),
Expand All @@ -276,6 +281,7 @@ def __init__(
alternative_file_formats=[
f.value for f in self._alternative_file_formats
],
is_blocked=self._is_blocked,
)

if homepage:
Expand Down Expand Up @@ -440,6 +446,13 @@ def alternative_file_formats(self) -> Sequence[file_adapters.FileFormat]:
def metadata(self) -> Metadata | None:
return self._metadata

@property
def is_blocked(self) -> str | None:
return self._is_blocked

def set_is_blocked(self, is_blocked: str) -> None:
self._is_blocked = is_blocked

@property
def supervised_keys(self) -> Optional[SupervisedKeysType]:
if not self.as_proto.HasField("supervised_keys"):
Expand Down Expand Up @@ -941,6 +954,7 @@ def __getstate__(self):
"license": self.redistribution_info.license,
"split_dict": self.splits,
"alternative_file_formats": self.alternative_file_formats,
"is_blocked": self.is_blocked,
}
def __setstate__(self, state):
# LINT.IfChange(setstate)
Expand All @@ -956,6 +970,7 @@ def __setstate__(self, state):
license=state["license"],
split_dict=state["split_dict"],
alternative_file_formats=state["alternative_file_formats"],
is_blocked=state["is_blocked"],
)
# LINT.ThenChange(:dataset_info_args)

Expand Down
6 changes: 5 additions & 1 deletion tensorflow_datasets/core/proto/dataset_info.proto
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,9 @@ message DatasetInfo {
// The data that was used to generate this dataset.
repeated DataSourceAccess data_source_accesses = 20;

// Next available: 22
// A message explaining why the dataset is blocked. If empty, it means that
// the dataset is not blocked.
string is_blocked = 23;

// Next available: 24
}
15 changes: 8 additions & 7 deletions tensorflow_datasets/core/proto/dataset_info_generated_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
b' \x01(\t\x12\x10\n\x08\x64\x61ta_dir\x18\x04'
b' \x01(\t\x12\x14\n\x0c\x64s_namespace\x18\x05'
b' \x01(\t\x12\r\n\x05split\x18\x06'
b' \x01(\t"\xd6\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01'
b' \x01(\t"\xea\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01'
b' \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02'
b' \x01(\t\x12\x0f\n\x07version\x18\t \x01(\t\x12I\n\rrelease_notes\x18\x12'
b' \x03(\x0b\x32\x32.tensorflow_datasets.DatasetInfo.ReleaseNotesEntry\x12\x13\n\x0b\x63onfig_name\x18\r'
Expand All @@ -88,7 +88,8 @@
b' \x01(\x08\x12\x13\n\x0b\x66ile_format\x18\x11 \x01(\t\x12'
b' \n\x18\x61lternative_file_formats\x18\x16'
b' \x03(\t\x12\x43\n\x14\x64\x61ta_source_accesses\x18\x14'
b' \x03(\x0b\x32%.tensorflow_datasets.DataSourceAccess\x1a\x33\n\x11ReleaseNotesEntry\x12\x0b\n\x03key\x18\x01'
b' \x03(\x0b\x32%.tensorflow_datasets.DataSourceAccess\x12\x12\n\nis_blocked\x18\x17'
b' \x01(\t\x1a\x33\n\x11ReleaseNotesEntry\x12\x0b\n\x03key\x18\x01'
b' \x01(\t\x12\r\n\x05value\x18\x02'
b' \x01(\t:\x02\x38\x01\x1a\x38\n\x16\x44ownloadChecksumsEntry\x12\x0b\n\x03key\x18\x01'
b' \x01(\t\x12\r\n\x05value\x18\x02'
Expand Down Expand Up @@ -146,9 +147,9 @@
_TFDSDATASETREFERENCE._serialized_start = 1280
_TFDSDATASETREFERENCE._serialized_end = 1404
_DATASETINFO._serialized_start = 1407
_DATASETINFO._serialized_end = 2389
_DATASETINFO_RELEASENOTESENTRY._serialized_start = 2280
_DATASETINFO_RELEASENOTESENTRY._serialized_end = 2331
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2333
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2389
_DATASETINFO._serialized_end = 2409
_DATASETINFO_RELEASENOTESENTRY._serialized_start = 2300
_DATASETINFO_RELEASENOTESENTRY._serialized_end = 2351
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2353
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2409
# @@protoc_insertion_point(module_scope)
21 changes: 21 additions & 0 deletions tensorflow_datasets/core/read_only_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
self.name = info_proto.name
self.VERSION = version_lib.Version(info_proto.version) # pylint: disable=invalid-name
self.RELEASE_NOTES = info_proto.release_notes or {} # pylint: disable=invalid-name
self.BLOCKED_VERSIONS = self._restore_blocked_versions(info_proto) # pylint: disable=invalid-name

if info_proto.module_name:
# Overwrite the module so documenting `ReadOnlyBuilder` point to the
Expand All @@ -92,6 +93,7 @@ def __init__(
config=builder_config,
version=info_proto.version,
)
self.assert_is_not_blocked()

# For pickling, should come after super.__init__ which is setting that same
# _original_state attribute.
Expand All @@ -103,6 +105,25 @@ def __init__(
'was generated with an old TFDS version (<=3.2.1).'
)

def _restore_blocked_versions(
self, info_proto: dataset_info_pb2.DatasetInfo
) -> version_lib.BlockedVersions | None:
"""Restores the blocked version information from the dataset info proto.
Args:
info_proto: DatasetInfo describing the name, config, etc of the requested
dataset.
Returns:
None if the dataset is not blocked, or a populated BlockedVersions object.
"""
if info_proto.is_blocked:
configs = {
info_proto.version: {info_proto.config_name: info_proto.is_blocked}
}
return version_lib.BlockedVersions(configs=configs)
return None

def _create_builder_config(
self,
builder_config: str | dataset_builder.BuilderConfig | None,
Expand Down
23 changes: 23 additions & 0 deletions tensorflow_datasets/core/read_only_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,29 @@ def test_builder_from_metadata(
assert str(builder.info.features) == str(dummy_features)


def test_restore_blocked_versions(
    code_builder: dataset_builder.DatasetBuilder,
    dummy_features: features_dict.FeaturesDict,
):
  """Building from metadata of a blocked variant raises the blocked error."""
  blocking_reason = 'some reason for blocking'
  metadata_proto = dataset_info_pb2.DatasetInfo(
      name='abcd',
      description='efgh',
      config_name='en',
      config_description='something',
      version='0.1.0',
      release_notes={'0.1.0': 'release description'},
      citation='some citation',
      features=dummy_features.to_proto(),
      is_blocked=blocking_reason,
  )
  with pytest.raises(utils.DatasetVariantBlockedError, match=blocking_reason):
    read_only_builder.builder_from_metadata(
        code_builder.data_dir, info_proto=metadata_proto
    )


def test_builder_from_directory_dir_not_exists(tmp_path: pathlib.Path):
with pytest.raises(FileNotFoundError, match='Could not load dataset info'):
read_only_builder.builder_from_directory(tmp_path)
Expand Down

0 comments on commit 551e9d2

Please sign in to comment.