From 870d16a5f0c095a3e58896de628a04fe198d7191 Mon Sep 17 00:00:00 2001 From: Tim Semenov Date: Thu, 21 Nov 2024 09:35:02 -0800 Subject: [PATCH] Simplify `ReadFromTFDS`. PiperOrigin-RevId: 698813440 --- tensorflow_datasets/core/beam_utils.py | 8 ++------ tensorflow_datasets/core/beam_utils_test.py | 2 -- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tensorflow_datasets/core/beam_utils.py b/tensorflow_datasets/core/beam_utils.py index 51e33c04250..b0434e83c70 100644 --- a/tensorflow_datasets/core/beam_utils.py +++ b/tensorflow_datasets/core/beam_utils.py @@ -19,7 +19,6 @@ from typing import Any from tensorflow_datasets.core import dataset_builder -from tensorflow_datasets.core import lazy_imports_lib from tensorflow_datasets.core import naming from tensorflow_datasets.core.utils import shard_utils from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam @@ -30,15 +29,13 @@ ] -@lazy_imports_lib.beam_ptransform_fn def ReadFromTFDS( # pylint: disable=invalid-name - pipeline, builder: dataset_builder.DatasetBuilder, split: str, workers_per_shard: int = 1, **as_dataset_kwargs: Any, ): - """Creates a beam pipeline yielding TFDS examples. + """Creates a beam PCollection yielding TFDS examples. Each dataset shard will be processed in parallel. @@ -63,7 +60,6 @@ def ReadFromTFDS( # pylint: disable=invalid-name examples will be used. Args: - pipeline: beam pipeline (automatically set) builder: Dataset builder to load split: Split name to load (e.g. `train+test`, `train`) workers_per_shard: number of workers that should read a shard in parallel. @@ -132,7 +128,7 @@ def load_shard(file_instruction: shard_utils.FileInstruction): # pylint: disabl value=len(file_instructions), namespace='ReadFromTFDS', ) - return pipeline | beam.Create(file_instructions) | beam.FlatMap(load_shard) + return beam.Create(file_instructions) | beam.FlatMap(load_shard) @functools.lru_cache(None) diff --git a/tensorflow_datasets/core/beam_utils_test.py b/tensorflow_datasets/core/beam_utils_test.py index 13a87d34135..16c67ae1c07 100644 --- a/tensorflow_datasets/core/beam_utils_test.py +++ b/tensorflow_datasets/core/beam_utils_test.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for beam_utils.""" - import os import pathlib from typing import Optional