diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 92783b361..0e4a4f2be 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -718,10 +718,14 @@ def decode_example(self, serialized_example): [1], tf.int64, getattr(self._hparams, "sampling_keep_top_k", -1)) if data_items_to_decoders is None: - data_items_to_decoders = { - field: contrib.slim().tfexample_decoder.Tensor(field) - for field in data_fields - } + data_items_to_decoders = {} + for field in data_fields: + if data_fields[field].dtype is tf.string: + default_value = b"" + else: + default_value = 0 + data_items_to_decoders[field] = contrib.slim().tfexample_decoder.Tensor( + field, default_value=default_value) decoder = contrib.slim().tfexample_decoder.TFExampleDecoder( data_fields, data_items_to_decoders)