
add support for drop_remainder and deterministic
n3011 committed Jan 7, 2022
1 parent 2dfc8c6 commit 3fd27f6
Showing 3 changed files with 86 additions and 4 deletions.
57 changes: 57 additions & 0 deletions datum/configs/tfr_configs.py
@@ -145,6 +145,63 @@ class DatasetConfigs(ConfigBase):
       serving.',
       default_factory=lambda: False,
   )
+  drop_remainder = create_config(
+      name='drop_remainder',
+      ty=bool,
+      docstring='Whether the last batch should be dropped in the case it has fewer than\
+      batch_size elements',
+      default_factory=lambda: True,
+  )
+  drop_remainder_val = create_config(
+      name='drop_remainder_val',
+      ty=bool,
+      docstring='Whether the last batch should be dropped in the case it has fewer than\
+      batch_size elements',
+      default_factory=lambda: False,
+  )
+  drop_remainder_test = create_config(
+      name='drop_remainder_test',
+      ty=bool,
+      docstring='Whether the last batch should be dropped in the case it has fewer than\
+      batch_size elements',
+      default_factory=lambda: False,
+  )
+  deterministic = create_config(
+      name='deterministic',
+      ty=bool,
+      docstring='When `num_parallel_calls` is specified, if this boolean is specified\
+      (True or False), it controls the order in which the transformation produces elements.\
+      If set to False, the transformation is allowed to yield elements out of order to trade\
+      determinism for performance. If not specified, the `tf.data.Options.deterministic`\
+      option (True by default) controls the behavior.',
+      default_factory=lambda: False,
+  )
+  deterministic_val = create_config(
+      name='deterministic_val',
+      ty=bool,
+      docstring='When `num_parallel_calls` is specified, if this boolean is specified\
+      (True or False), it controls the order in which the transformation produces elements.\
+      If set to False, the transformation is allowed to yield elements out of order to trade\
+      determinism for performance. If not specified, the `tf.data.Options.deterministic`\
+      option (True by default) controls the behavior.',
+      default_factory=lambda: False,
+  )
+  deterministic_test = create_config(
+      name='deterministic_test',
+      ty=bool,
+      docstring='When `num_parallel_calls` is specified, if this boolean is specified\
+      (True or False), it controls the order in which the transformation produces elements.\
+      If set to False, the transformation is allowed to yield elements out of order to trade\
+      determinism for performance. If not specified, the `tf.data.Options.deterministic`\
+      option (True by default) controls the behavior.',
+      default_factory=lambda: False,
+  )
+  num_parallel_calls = create_config(
+      name='num_parallel_calls',
+      ty=int,
+      docstring='The number of batches to compute asynchronously in parallel.',
+      default_factory=lambda: 1,
+  )
   batch_size_train = create_config(
       name='batch_size_train',
       ty=int,
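
Taken together, the new options give per-split control over batching. A minimal usage sketch follows; the import paths and the tfrecord directory are assumptions based on the file layout in this commit, not verified API:

from datum.configs import DatasetConfigs  # assumed import path
from datum.reader.dataset import Dataset  # assumed import path

configs = DatasetConfigs()
configs.batch_size_train = 32
# Train: drop the short final batch so every batch has a static shape.
configs.drop_remainder = True
# Val/test: keep every example, even if the last batch is smaller.
configs.drop_remainder_val = False
configs.drop_remainder_test = False
# deterministic=True routes batching through dataset.batch with
# num_parallel_calls (see datum/reader/dataset.py below); False keeps
# the padded_batch path.
configs.deterministic = False
configs.num_parallel_calls = 4

# Placeholder path to a datum-exported tfrecord dataset.
dataset = Dataset('/path/to/tfrecords', configs)
# train_fn/val_fn/test_fn (below) then pick these configs up per split.
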
30 changes: 26 additions & 4 deletions datum/reader/dataset.py
@@ -56,6 +56,8 @@ def _read(self,
             shuffle: bool = False,
             echoing: Optional[int] = None,
             full_dataset: bool = False,
+            drop_remainder: bool = False,
+            deterministic: bool = False,
             pre_batching_callback: Optional[Callable[[Dict], Dict]] = None,
             post_batching_callback: Optional[Callable[[Dict], Dict]] = None) -> DatasetType:
     """Read and process data from tfrecord files.
@@ -70,6 +72,13 @@ def _read(self,
       shuffle: whether to shuffle examples in the dataset.
       echoing: batch echoing factor, if not None perform batch_echoing.
       full_dataset: if true, return the dataset as a single batch for datasets with a single element.
+      drop_remainder: Whether the last batch should be dropped in the case it has fewer than
+        `batch_size` elements.
+      deterministic: When `num_parallel_calls` is specified, if this boolean is specified
+        (True or False), it controls the order in which the transformation produces elements.
+        If set to False, the transformation is allowed to yield elements out of order to trade
+        determinism for performance. If not specified, the `tf.data.Options.deterministic`
+        option (True by default) controls the behavior.
       pre_batching_callback: data processing to apply before batching.
       post_batching_callback: data processing to apply post batching. This function should support
         batch processing.
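
Both callbacks are plain Dict -> Dict transforms. A hypothetical pair, purely for illustration (the 'image' feature key is invented):

import tensorflow as tf

def pre_batching_callback(example):
  # Runs once per example, before batching; 'image' is a made-up feature key.
  example['image'] = tf.image.convert_image_dtype(example['image'], tf.float32)
  return example

def post_batching_callback(batch):
  # Runs on a whole batch, so every op must accept a leading batch dimension
  # (4-D [batch, height, width, channels] for images).
  batch['image'] = tf.image.random_flip_left_right(batch['image'])
  return batch
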
@@ -93,16 +102,23 @@ def _read(self,
     if bucket_fn:
       logging.info(
           f'Using bucketing to batch data, bucket_params: {self._dataset_configs.bucket_op}')
-      bucket_op = tf.data.experimental.bucket_by_sequence_length(
+      dataset = dataset.bucket_by_sequence_length(
           bucket_fn,
           self._dataset_configs.bucket_op.bucket_boundaries,
           self._dataset_configs.bucket_op.bucket_batch_sizes,
           padded_shapes=tf.compat.v1.data.get_output_shapes(dataset),
           padding_values=None,
+          drop_remainder=drop_remainder,
           pad_to_bucket_boundary=False)
-      dataset = dataset.apply(bucket_op)
-    elif batch_size:
-      dataset = dataset.padded_batch(batch_size, padded_shapes=self.padded_shapes)
+    elif batch_size and not deterministic:
+      dataset = dataset.padded_batch(batch_size,
+                                     padded_shapes=self.padded_shapes,
+                                     drop_remainder=drop_remainder)
+    elif batch_size and deterministic:
+      dataset = dataset.batch(batch_size,
+                              drop_remainder=drop_remainder,
+                              num_parallel_calls=self._dataset_configs.num_parallel_calls,
+                              deterministic=deterministic)
     if echoing:
       dataset = dataset.flat_map(
           lambda example: tf.data.Dataset.from_tensors(example).repeat(echoing))
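
The semantics the two new arguments lean on are easiest to see in a bare tf.data sketch, independent of datum:

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# drop_remainder=True discards the short final batch: 10 % 4 = 2 leftover
# elements are dropped, leaving exactly two batches of four.
for batch in ds.batch(4, drop_remainder=True):
  print(batch.numpy())  # [0 1 2 3], then [4 5 6 7]

# With num_parallel_calls set, deterministic=False lets batches be produced
# out of input order for throughput; deterministic=True preserves order.
fast = ds.batch(4,
                drop_remainder=True,
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=False)
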
@@ -152,6 +168,8 @@ def train_fn(self,
         shuffle=shuffle,
         echoing=self._dataset_configs.echoing,
         full_dataset=self._dataset_configs.full_dataset,
+        drop_remainder=self._dataset_configs.drop_remainder,
+        deterministic=self._dataset_configs.deterministic,
         pre_batching_callback=self._dataset_configs.pre_batching_callback_train,
         post_batching_callback=self._dataset_configs.post_batching_callback_train)

@@ -176,6 +194,8 @@ def val_fn(self,
         shuffle=shuffle,
         echoing=None,
         full_dataset=self._dataset_configs.full_dataset,
+        drop_remainder=self._dataset_configs.drop_remainder_val,
+        deterministic=self._dataset_configs.deterministic_val,
         pre_batching_callback=self._dataset_configs.pre_batching_callback_val,
         post_batching_callback=self._dataset_configs.post_batching_callback_val)

@@ -200,5 +220,7 @@ def test_fn(self,
         shuffle=shuffle,
         echoing=None,
         full_dataset=self._dataset_configs.full_dataset,
+        drop_remainder=self._dataset_configs.drop_remainder_test,
+        deterministic=self._dataset_configs.deterministic_test,
         pre_batching_callback=self._dataset_configs.pre_batching_callback_test,
         post_batching_callback=self._dataset_configs.post_batching_callback_test)
3 changes: 3 additions & 0 deletions tests/test_dataset.py
@@ -31,6 +31,9 @@ def setUp(self):
     configs = DatasetConfigs()
     configs.batch_size_train = 1
     configs.batch_size_val = 1
+    configs.drop_remainder = True
+    configs.drop_remainder_val = True
+    configs.deterministic_val = True
     self._dataset = Dataset(self.tempdir, configs)
 
   def tearDown(self):
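
The test above only wires the new flags through the configs; a small standalone check of the behavior being relied on could look like this sketch (plain tf.data rather than the datum API):

import tensorflow as tf

class DropRemainderTest(tf.test.TestCase):

  def test_short_final_batch_is_dropped(self):
    ds = tf.data.Dataset.range(5).batch(2, drop_remainder=True)
    batches = list(ds.as_numpy_iterator())
    # The trailing single-element batch is discarded: 2 batches, not 3.
    self.assertEqual(len(batches), 2)

if __name__ == '__main__':
  tf.test.main()
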
