diff --git a/datum/configs/tfr_configs.py b/datum/configs/tfr_configs.py index 72adf13..366cd8b 100644 --- a/datum/configs/tfr_configs.py +++ b/datum/configs/tfr_configs.py @@ -145,6 +145,63 @@ class DatasetConfigs(ConfigBase): serving.', default_factory=lambda: False, ) + drop_remainder = create_config( + name='drop_remainder', + ty=bool, + docstring='Whether the last batch should be dropped in the case it has fewer than\ + batch_size elements', + default_factory=lambda: True, + ) + drop_remainder_val = create_config( + name='drop_remainder_val', + ty=bool, + docstring='Whether the last batch should be dropped in the case it has fewer than\ + batch_size elements', + default_factory=lambda: False, + ) + drop_remainder_test = create_config( + name='drop_remainder_test', + ty=bool, + docstring='Whether the last batch should be dropped in the case it has fewer than\ + batch_size elements', + default_factory=lambda: False, + ) + deterministic = create_config( + name='deterministic', + ty=bool, + docstring='When `num_parallel_calls` is specified, if this boolean is specified\ + (True or False), it controls the order in which the transformation produces elements.\ + If set to False, the transformation is allowed to yield elements out of order to trade\ + determinism for performance. If not specified, the `tf.data.Options.deterministic`\ + option (True by default) controls the behavior.', + default_factory=lambda: False, + ) + deterministic_val = create_config( + name='deterministic_val', + ty=bool, + docstring='When `num_parallel_calls` is specified, if this boolean is specified\ + (True or False), it controls the order in which the transformation produces elements.\ + If set to False, the transformation is allowed to yield elements out of order to trade\ + determinism for performance. If not specified, the `tf.data.Options.deterministic`\ + option (True by default) controls the behavior.', + default_factory=lambda: False, + ) + deterministic_test = create_config( + name='deterministic_test', + ty=bool, + docstring='When `num_parallel_calls` is specified, if this boolean is specified\ + (True or False), it controls the order in which the transformation produces elements.\ + If set to False, the transformation is allowed to yield elements out of order to trade\ + determinism for performance. If not specified, the `tf.data.Options.deterministic`\ + option (True by default) controls the behavior.', + default_factory=lambda: False, + ) + num_parallel_calls = create_config( + name='num_parallel_calls', + ty=int, + docstring='The number of batches to compute asynchronously in parallel.', + default_factory=lambda: 1, + ) batch_size_train = create_config( name='batch_size_train', ty=int, diff --git a/datum/reader/dataset.py b/datum/reader/dataset.py index 4fde60f..f542ee0 100644 --- a/datum/reader/dataset.py +++ b/datum/reader/dataset.py @@ -56,6 +56,8 @@ def _read(self, shuffle: bool = False, echoing: Optional[int] = None, full_dataset: bool = False, + drop_remainder: bool = False, + deterministic: bool = False, pre_batching_callback: Optional[Callable[[Dict], Dict]] = None, post_batching_callback: Optional[Callable[[Dict], Dict]] = None) -> DatasetType: """Read and process data from tfrecord files. @@ -70,6 +72,13 @@ def _read(self, shuffle: whether to shuffle examples in the dataset. echoing: batch echoing factor, if not None perform batch_echoing. full_dataset: if true, return the dataset as a single batch for dataset with single element. + drop_remainder: Whether the last batch should be dropped in the case it has fewer than + `batch_size` elements. + deterministic: When `num_parallel_calls` is specified, if this boolean is specified + (True or False), it controls the order in which the transformation produces elements. + If set to False, the transformation is allowed to yield elements out of order to trade + determinism for performance. If not specified, the `tf.data.Options.deterministic` + option (True by default) controls the behavior. pre_batching_callback: data processing to apply before batching. post_batching_callback: data processing to apply post batching. This fucntion should support batch processsing. @@ -93,16 +102,23 @@ def _read(self, if bucket_fn: logging.info( f'Using bucketing to batch data, bucket_params: {self._dataset_configs.bucket_op}') - bucket_op = tf.data.experimental.bucket_by_sequence_length( + dataset = dataset.bucket_by_sequence_length( bucket_fn, self._dataset_configs.bucket_op.bucket_boundaries, self._dataset_configs.bucket_op.bucket_batch_sizes, padded_shapes=tf.compat.v1.data.get_output_shapes(dataset), padding_values=None, + drop_remainder=drop_remainder, pad_to_bucket_boundary=False) - dataset = dataset.apply(bucket_op) - elif batch_size: - dataset = dataset.padded_batch(batch_size, padded_shapes=self.padded_shapes) + elif batch_size and not deterministic: + dataset = dataset.padded_batch(batch_size, + padded_shapes=self.padded_shapes, + drop_remainder=drop_remainder) + elif batch_size and deterministic: + dataset = dataset.batch(batch_size, + drop_remainder=drop_remainder, + num_parallel_calls=self._dataset_configs.num_parallel_calls, + deterministic=deterministic) if echoing: dataset = dataset.flat_map( lambda example: tf.data.Dataset.from_tensors(example).repeat(echoing)) @@ -152,6 +168,8 @@ def train_fn(self, shuffle=shuffle, echoing=self._dataset_configs.echoing, full_dataset=self._dataset_configs.full_dataset, + drop_remainder=self._dataset_configs.drop_remainder, + deterministic=self._dataset_configs.deterministic, pre_batching_callback=self._dataset_configs.pre_batching_callback_train, post_batching_callback=self._dataset_configs.post_batching_callback_train) @@ -176,6 +194,8 @@ def val_fn(self, shuffle=shuffle, echoing=None, full_dataset=self._dataset_configs.full_dataset, + drop_remainder=self._dataset_configs.drop_remainder_val, + deterministic=self._dataset_configs.deterministic_val, pre_batching_callback=self._dataset_configs.pre_batching_callback_val, post_batching_callback=self._dataset_configs.post_batching_callback_val) @@ -200,5 +220,7 @@ def test_fn(self, shuffle=shuffle, echoing=None, full_dataset=self._dataset_configs.full_dataset, + drop_remainder=self._dataset_configs.drop_remainder_test, + deterministic=self._dataset_configs.deterministic_test, pre_batching_callback=self._dataset_configs.pre_batching_callback_test, post_batching_callback=self._dataset_configs.post_batching_callback_test) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2750d73..6beca9b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -31,6 +31,9 @@ def setUp(self): configs = DatasetConfigs() configs.batch_size_train = 1 configs.batch_size_val = 1 + configs.drop_remainder = True + configs.drop_remainder_val = True + configs.deterministic_val = True self._dataset = Dataset(self.tempdir, configs) def tearDown(self):