diff --git a/datum/generator/generator.py b/datum/generator/generator.py
index a084337..7f2c0b3 100644
--- a/datum/generator/generator.py
+++ b/datum/generator/generator.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import abc
+from pathlib import Path
 from typing import Any, Union

 from datum.configs import ConfigBase
@@ -28,11 +29,16 @@ class DatumGenerator():
   Args:
     path: a path to the data store location.
     gen_config: configs for generator.
+
+  Raises:
+    ValueError: If input `path` is None or the provided path does not exist.
   """

   def __init__(self, path: str, gen_config: Union[AttrDict, ConfigBase] = None):
     self.path = path
     self.gen_config = gen_config
+    if not self.path or not Path(self.path).exists():
+      raise ValueError("Input `path` was not provided or does not exist.")

   def __call__(self, **kwargs: Any) -> GeneratorReturnType:
     """Returns a generator to iterate over the processed input data."""
diff --git a/datum/generator/text.py b/datum/generator/text.py
index e5db0a3..58298b7 100644
--- a/datum/generator/text.py
+++ b/datum/generator/text.py
@@ -14,8 +14,8 @@
 # limitations under the License.

 import json
-import os
 from ast import literal_eval
+from pathlib import Path
 from typing import Any

 import tensorflow as tf
@@ -75,13 +75,17 @@ def generate_datum(self, **kwargs: Any) -> GeneratorReturnType:
     split = kwargs.get('split')
     if not split:
       raise ValueError('Pass a valid split name to generate data `__call__` method.')
-    json_path = kwargs.get('json_path', split + '.json')
-    with tf.io.gfile.GFile(os.path.join(self.path, json_path)) as json_f:
-      for key, value in json.load(json_f).items():
-        datum = {'text': value['text']}
-        for label_key, val in value['label'].items():
-          try:
-            datum[label_key] = literal_eval(val)
-          except ValueError:
-            datum[label_key] = val
-        yield key, datum
+    split_data_files = [
+        filename for filename in Path(self.path).iterdir()
+        if (filename.name.startswith(split) and filename.name.endswith(".json"))
+    ]
+    for split_data_file in split_data_files:
+      with tf.io.gfile.GFile(str(split_data_file)) as json_f:
+        for key, value in json.load(json_f).items():
+          datum = {'text': value['text']}
+          for label_key, val in value['label'].items():
+            try:
+              datum[label_key] = literal_eval(val)
+            except Exception:  # pylint: disable=broad-except
+              datum[label_key] = val
+          yield key, datum
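A minimal usage sketch of the multi-file lookup added above. This is not part
of the patch: the directory layout and data are hypothetical, and the import
path is assumed from the repository layout (datum/generator/text.py).

    import json
    import tempfile
    from pathlib import Path

    from datum.generator import text

    tmpdir = tempfile.mkdtemp()
    # Both files match the 'train' prefix and the '.json' suffix, so both
    # now feed the 'train' split; previously only 'train.json' was read.
    Path(tmpdir, 'train.json').write_text(
        json.dumps({1: {'text': 'first file', 'label': {'polarity': 1}}}))
    Path(tmpdir, 'train_2.json').write_text(
        json.dumps({2: {'text': 'second file', 'label': {'polarity': 0}}}))

    gen = text.TextJsonDatumGenerator(tmpdir)
    for key, datum in gen(split='train'):
      print(key, datum)  # yields examples from both files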
"datum_to_type_and_shape_mapping.json"), + "w") as js_f: + logging.info(f"Saving datum type and shape metadata to {self._base_path}.") types_shapes = datum_to_type_and_shape(datum, self.sparse_features) json.dump(types_shapes, js_f) def create_records(self) -> None: """Create tfrecords from given generator.""" - logging.info('Caching serialized binary example to cache.') + logging.info("Caching serialized binary example to cache.") self.cache_records() - logging.info('Writing data from cache to disk in `.tfrecord` format.') + logging.info("Writing data from cache to disk in `.tfrecord` format.") self.flush() def add_shape_fields(self, datum: DatumType) -> DatumType: @@ -108,7 +108,7 @@ def add_shape_fields(self, datum: DatumType) -> DatumType: if sparse_key in datum: value = np.asarray(datum[sparse_key]) if len(value.shape) >= 2: - new_fields[sparse_key + '_shape'] = list(value.shape) + new_fields[sparse_key + "_shape"] = list(value.shape) datum.update(new_fields) return datum @@ -134,7 +134,7 @@ def flush_records(self) -> Tuple[Dict[str, Dict[str, int]], int]: except DuplicatedKeysError as err: shard_utils.raise_error_for_duplicated_keys(err) shard_info = { - self.split: {spec.path.split('/')[-1]: int(spec.examples_number) + self.split: {spec.path.split("/")[-1]: int(spec.examples_number) for spec in shard_specs} } self.save_shard_info(shard_info) @@ -147,9 +147,9 @@ def save_shard_info(self, shard_info: Dict[str, Dict[str, int]]) -> None: Args: shard_info: input shard info dict. """ - if os.path.isfile(os.path.join(self._base_path, 'shard_info.json')): - with tf.io.gfile.GFile(os.path.join(self._base_path, 'shard_info.json'), 'r') as si_f: + if os.path.isfile(os.path.join(self._base_path, "shard_info.json")): + with tf.io.gfile.GFile(os.path.join(self._base_path, "shard_info.json"), "r") as si_f: prev_shard_info = json.load(si_f) shard_info.update(prev_shard_info) - with tf.io.gfile.GFile(os.path.join(self._base_path, 'shard_info.json'), 'w') as si_f: + with tf.io.gfile.GFile(os.path.join(self._base_path, "shard_info.json"), "w") as si_f: json.dump(shard_info, si_f) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6beca9b..e15a0b8 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -125,19 +125,47 @@ class TestTextJsonDataset(absltest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() - _test_create_textjson_records(self.tempdir) - configs = DatasetConfigs() - configs.batch_size_train = 3 - self._dataset = Dataset(self.tempdir, configs) + self.configs = DatasetConfigs() + self.configs.batch_size_train = 3 def tearDown(self): rmtree(self.tempdir) - def test_train_fn(self): + def test_train_fn_json(self): + expected_data = _test_create_textjson_records(self.tempdir, use_two_files=False) + self._dataset = Dataset(self.tempdir, self.configs) ds = self._dataset.train_fn('train', False) batch = next(iter(ds)) self.assertEqual(batch['text'].shape, [3]) self.assertEqual(batch['polarity'].shape, [3]) - np.array_equal(batch['polarity'].numpy(), [1, 2, 0]) - self.assertEqual(list(batch['text'].numpy()), - [b'this is label file', b'this is json file', b'this is text file']) + batch_data = {} + for idx, text_val in enumerate(batch["text"].numpy()): + batch_data[idx] = { + "text": text_val.decode("utf-8"), + "label": { + "polarity": batch["polarity"].numpy()[idx], + "question": batch["question"].numpy()[idx].decode("utf-8"), + } + } + for _, value in batch_data.items(): + assert value in list(expected_data.values()) + + def test_train_fn_two_files(self): + 
+    expected_data = _test_create_textjson_records(self.tempdir, use_two_files=True)
+    self.configs.batch_size_train = 6
+    self._dataset = Dataset(self.tempdir, self.configs)
+    ds = self._dataset.train_fn('train', False)
+    batch = next(iter(ds))
+    self.assertEqual(batch['text'].shape, [6])
+    self.assertEqual(batch['polarity'].shape, [6])
+    batch_data = {}
+    for idx, text_val in enumerate(batch["text"].numpy()):
+      batch_data[idx] = {
+          "text": text_val.decode("utf-8"),
+          "label": {
+              "polarity": batch["polarity"].numpy()[idx],
+              "question": batch["question"].numpy()[idx].decode("utf-8"),
+          }
+      }
+    for value in batch_data.values():
+      self.assertIn(value, list(expected_data.values()))
diff --git a/tests/test_generator_text.py b/tests/test_generator_text.py
index 831d89b..c8a8810 100644
--- a/tests/test_generator_text.py
+++ b/tests/test_generator_text.py
@@ -26,34 +26,70 @@ class TestTextJsonDatumGenerator(absltest.TestCase):

   def setUp(self):
     self.tempdir = tempfile.mkdtemp()
-    self.data = {
+    self.data_1 = {
         1: {
             'text': 'this is text file',
             'label': {
-                'polarity': 1
+                'polarity': 1,
+                'question': "meaning of this line?",
             }
         },
         2: {
             'text': 'this is json file',
             'label': {
-                'polarity': 2
+                'polarity': 2,
+                'question': "meaning of this sentence?",
             }
         },
         3: {
             'text': 'this is label file',
             'label': {
-                'polarity': 0
+                'polarity': 0,
+                'question': "meaning of this para?",
             }
         },
     }
+    self.data_2 = {
+        4: {
+            'text': 'this is next text file',
+            'label': {
+                'polarity': 4,
+                'question': "meaning of next line?",
+            }
+        },
+        5: {
+            'text': 'this is next json file',
+            'label': {
+                'polarity': 5,
+                'question': "meaning of next sentence?",
+            }
+        },
+        6: {
+            'text': 'this is next label file',
+            'label': {
+                'polarity': 6,
+                'question': "meaning of next para?",
+            }
+        },
+    }
+    self.data = {**self.data_1, **self.data_2}
     with open(os.path.join(self.tempdir, 'train.json'), 'w') as f:
-      json.dump(self.data, f)
+      json.dump(self.data_1, f)
     self.gen_from_json = text.TextJsonDatumGenerator(self.tempdir)

   def tearDown(self):
     rmtree(self.tempdir)

   def test_generate_datum(self):
+    for key, datum in self.gen_from_json(split='train'):
+      self.assertEqual(datum['text'], self.data_1[literal_eval(key)]['text'])
+      self.assertEqual(datum['polarity'], self.data_1[literal_eval(key)]['label']['polarity'])
+      self.assertEqual(datum['question'], self.data_1[literal_eval(key)]['label']['question'])
+
+  def test_generate_datum_multiple_files(self):
+    with open(os.path.join(self.tempdir, 'train_2.json'), 'w') as f:
+      json.dump(self.data_2, f)
     for key, datum in self.gen_from_json(split='train'):
       self.assertEqual(datum['text'], self.data[literal_eval(key)]['text'])
       self.assertEqual(datum['polarity'], self.data[literal_eval(key)]['label']['polarity'])
+      self.assertEqual(datum['question'], self.data[literal_eval(key)]['label']['question'])
diff --git a/tests/utils.py b/tests/utils.py
index c02dc18..57758d3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -78,34 +78,68 @@ def _test_create_seg_records(path):
   writer.create_records()


-def _test_create_textjson_records(path):
+def _test_create_textjson_records(path, use_two_files=False):
   tempdir = tempfile.mkdtemp()
   data = {
       1: {
           'text': 'this is text file',
           'label': {
-              'polarity': 1
+              'polarity': 1,
+              'question': "meaning of this line?",
           }
       },
       2: {
           'text': 'this is json file',
           'label': {
-              'polarity': 2
+              'polarity': 2,
+              'question': "meaning of this sentence?",
           }
       },
       3: {
           'text': 'this is label file',
           'label': {
-              'polarity': 0
+              'polarity': 0,
+              'question': "meaning of this para?",
           }
       },
   }
+  num_examples = 3
+  final_data = data
+  if use_two_files:
+    data_2 = {
+        4: {
+            'text': 'this is next text file',
+            'label': {
+                'polarity': 4,
+                'question': "meaning of next line?",
+            }
+        },
+        5: {
+            'text': 'this is next json file',
+            'label': {
+                'polarity': 5,
+                'question': "meaning of next sentence?",
+            }
+        },
+        6: {
+            'text': 'this is next label file',
+            'label': {
+                'polarity': 6,
+                'question': "meaning of next para?",
+            }
+        },
+    }
+    with open(os.path.join(tempdir, 'train_2.json'), 'w') as f:
+      json.dump(data_2, f)
+    num_examples = 6
+    final_data = {**data, **data_2}
   with open(os.path.join(tempdir, 'train.json'), 'w') as f:
     json.dump(data, f)
   gen_from_json = text.TextJsonDatumGenerator(tempdir)
   serializer = DatumSerializer('text')
   Path(path).mkdir(parents=True, exist_ok=True)
   textjson_gen = text.TextJsonDatumGenerator(tempdir)
-  writer = TFRecordWriter(textjson_gen, serializer, path, 'train', 3)
+  writer = TFRecordWriter(textjson_gen, serializer, path, 'train', num_examples)
   writer.create_records()
   rmtree(tempdir)
+  return final_data
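For reference, an end-to-end sketch of the flow the updated test helper
exercises. This is a hypothetical example, not project documentation: the
DatumSerializer import path is assumed from the package layout, and the data
is made up.

    import json
    import os
    import tempfile

    from datum.generator import text
    from datum.serializer.serializer import DatumSerializer  # assumed path
    from datum.writer.tfrecord_writer import TFRecordWriter

    data_dir = tempfile.mkdtemp()
    out_dir = tempfile.mkdtemp()
    with open(os.path.join(data_dir, 'train.json'), 'w') as f:
      json.dump({1: {'text': 'a', 'label': {'polarity': 1}},
                 2: {'text': 'b', 'label': {'polarity': 0}}}, f)

    generator = text.TextJsonDatumGenerator(data_dir)
    serializer = DatumSerializer('text')
    # The example count is passed through because it must now cover records
    # from every matching 'train*.json' file, not just 'train.json'.
    writer = TFRecordWriter(generator, serializer, out_dir, 'train', 2)
    writer.create_records()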