Improved support for the Text dataset, including GLUE datasets. (#21)
* support for multiple json files in text dataset

* update tests for text dataset
n3011 authored Aug 15, 2022
1 parent 306055b commit 4f196bc
Showing 6 changed files with 150 additions and 41 deletions.
6 changes: 6 additions & 0 deletions datum/generator/generator.py
@@ -13,6 +13,7 @@
# limitations under the License.

import abc
from pathlib import Path
from typing import Any, Union

from datum.configs import ConfigBase
@@ -28,11 +29,16 @@ class DatumGenerator():
Args:
path: a path to the data store location.
gen_config: configs for generator.
Raises:
ValueError: If input `path` is None or the provided path does not exist.
"""

def __init__(self, path: str, gen_config: Union[AttrDict, ConfigBase] = None):
self.path = path
self.gen_config = gen_config
if not self.path or not Path(self.path).exists():
  raise ValueError("Input `path` does not exist or was not provided.")

def __call__(self, **kwargs: Any) -> GeneratorReturnType:
"""Returns a generator to iterate over the processed input data."""
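
A minimal usage sketch of the new path validation, using TextJsonDatumGenerator (a concrete subclass from datum/generator/text.py below); the path is illustrative:

# Hypothetical path; DatumGenerator.__init__ now rejects missing paths.
from datum.generator import text

try:
    gen = text.TextJsonDatumGenerator('/data/does_not_exist')
except ValueError as err:
    print(err)  # Input `path` does not exist or was not provided.
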
26 changes: 15 additions & 11 deletions datum/generator/text.py
@@ -14,8 +14,8 @@
# limitations under the License.

import json
import os
from ast import literal_eval
from pathlib import Path
from typing import Any

import tensorflow as tf
@@ -75,13 +75,17 @@ def generate_datum(self, **kwargs: Any) -> GeneratorReturnType:
split = kwargs.get('split')
if not split:
raise ValueError('Pass a valid split name to the `__call__` method to generate data.')
json_path = kwargs.get('json_path', split + '.json')
with tf.io.gfile.GFile(os.path.join(self.path, json_path)) as json_f:
for key, value in json.load(json_f).items():
datum = {'text': value['text']}
for label_key, val in value['label'].items():
try:
datum[label_key] = literal_eval(val)
except ValueError:
datum[label_key] = val
yield key, datum
split_data_files = [
filename for filename in Path(self.path).iterdir()
if (filename.name.startswith(split) and filename.name.endswith(".json"))
]
for split_data_file in split_data_files:
with tf.io.gfile.GFile(split_data_file) as json_f:
for key, value in json.load(json_f).items():
datum = {'text': value['text']}
for label_key, val in value['label'].items():
try:
datum[label_key] = literal_eval(val)
except Exception: # pylint: disable=broad-except
datum[label_key] = val
yield key, datum
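
The matching rule above is a pure prefix/suffix test; a short sketch of which files a hypothetical data directory would contribute to the train split:

# Assumed layout: /data/my_text_ds/{train.json, train_2.json, val.json}.
from pathlib import Path

data_dir = Path('/data/my_text_ds')
matched = [
    f for f in data_dir.iterdir()
    if f.name.startswith('train') and f.name.endswith('.json')
]
# matched -> [train.json, train_2.json]; note that a file such as
# 'train_labels.json' would also match, since only the prefix and
# suffix are checked.
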
24 changes: 12 additions & 12 deletions datum/writer/tfrecord_writer.py
@@ -65,7 +65,7 @@ def __init__(self,
self.split = split
self.sparse_features = sparse_features or []
self.gen_kwargs = gen_kwargs or {}
self.gen_kwargs.update({'split': self.split})
self.gen_kwargs.update({"split": self.split})

def cache_records(self) -> None:
"""Write data to cache."""
@@ -74,22 +74,22 @@ def cache_records(self) -> None:
total=self.total_examples,
leave=False):
if self.sparse_features:
logging.debug(f'Adding shapes info to datum for sparse features: {self.sparse_features}.')
logging.debug(f"Adding shapes info to datum for sparse features: {self.sparse_features}.")
datum = self.add_shape_fields(datum)
serialized_record = self.serializer(datum)
self.shuffler.add(key, serialized_record)
self.current_examples += 1
with tf.io.gfile.GFile(os.path.join(self._base_path, 'datum_to_type_and_shape_mapping.json'),
'w') as js_f:
logging.info(f'Saving datum type and shape metadata to {self._base_path}.')
with tf.io.gfile.GFile(os.path.join(self._base_path, "datum_to_type_and_shape_mapping.json"),
"w") as js_f:
logging.info(f"Saving datum type and shape metadata to {self._base_path}.")
types_shapes = datum_to_type_and_shape(datum, self.sparse_features)
json.dump(types_shapes, js_f)

def create_records(self) -> None:
"""Create tfrecords from given generator."""
logging.info('Caching serialized binary example to cache.')
logging.info("Caching serialized binary example to cache.")
self.cache_records()
logging.info('Writing data from cache to disk in `.tfrecord` format.')
logging.info("Writing data from cache to disk in `.tfrecord` format.")
self.flush()

def add_shape_fields(self, datum: DatumType) -> DatumType:
@@ -108,7 +108,7 @@ def add_shape_fields(self, datum: DatumType) -> DatumType:
if sparse_key in datum:
value = np.asarray(datum[sparse_key])
if len(value.shape) >= 2:
new_fields[sparse_key + '_shape'] = list(value.shape)
new_fields[sparse_key + "_shape"] = list(value.shape)
datum.update(new_fields)
return datum
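
A worked sketch of the shape field this adds for one assumed 2-D sparse feature:

import numpy as np

# Assumed datum with a sparse feature named 'tokens'.
datum = {'tokens': [[1, 2, 3], [4, 5, 6]]}
value = np.asarray(datum['tokens'])
if len(value.shape) >= 2:
    datum['tokens_shape'] = list(value.shape)  # -> [2, 3]
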

@@ -134,7 +134,7 @@ def flush_records(self) -> Tuple[Dict[str, Dict[str, int]], int]:
except DuplicatedKeysError as err:
shard_utils.raise_error_for_duplicated_keys(err)
shard_info = {
self.split: {spec.path.split('/')[-1]: int(spec.examples_number)
self.split: {spec.path.split("/")[-1]: int(spec.examples_number)
for spec in shard_specs}
}
self.save_shard_info(shard_info)
@@ -147,9 +147,9 @@ def save_shard_info(self, shard_info: Dict[str, Dict[str, int]]) -> None:
Args:
shard_info: input shard info dict.
"""
if os.path.isfile(os.path.join(self._base_path, 'shard_info.json')):
with tf.io.gfile.GFile(os.path.join(self._base_path, 'shard_info.json'), 'r') as si_f:
if os.path.isfile(os.path.join(self._base_path, "shard_info.json")):
with tf.io.gfile.GFile(os.path.join(self._base_path, "shard_info.json"), "r") as si_f:
prev_shard_info = json.load(si_f)
shard_info.update(prev_shard_info)
with tf.io.gfile.GFile(os.path.join(self._base_path, 'shard_info.json'), 'w') as si_f:
with tf.io.gfile.GFile(os.path.join(self._base_path, "shard_info.json"), "w") as si_f:
json.dump(shard_info, si_f)
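
Per flush_records and save_shard_info above, shard_info.json maps each split to per-shard example counts; a hedged illustration (shard file names are hypothetical):

# Structure implied by `{spec.path.split("/")[-1]: int(spec.examples_number)}`;
# actual shard names depend on the writer.
shard_info = {
    'train': {
        'train.tfrecord-00000-of-00002': 3,
        'train.tfrecord-00001-of-00002': 3,
    }
}
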
44 changes: 36 additions & 8 deletions tests/test_dataset.py
@@ -125,19 +125,47 @@ class TestTextJsonDataset(absltest.TestCase):

def setUp(self):
self.tempdir = tempfile.mkdtemp()
_test_create_textjson_records(self.tempdir)
configs = DatasetConfigs()
configs.batch_size_train = 3
self._dataset = Dataset(self.tempdir, configs)
self.configs = DatasetConfigs()
self.configs.batch_size_train = 3

def tearDown(self):
rmtree(self.tempdir)

def test_train_fn(self):
def test_train_fn_json(self):
expected_data = _test_create_textjson_records(self.tempdir, use_two_files=False)
self._dataset = Dataset(self.tempdir, self.configs)
ds = self._dataset.train_fn('train', False)
batch = next(iter(ds))
self.assertEqual(batch['text'].shape, [3])
self.assertEqual(batch['polarity'].shape, [3])
np.array_equal(batch['polarity'].numpy(), [1, 2, 0])
self.assertEqual(list(batch['text'].numpy()),
[b'this is label file', b'this is json file', b'this is text file'])
batch_data = {}
for idx, text_val in enumerate(batch["text"].numpy()):
batch_data[idx] = {
"text": text_val.decode("utf-8"),
"label": {
"polarity": batch["polarity"].numpy()[idx],
"question": batch["question"].numpy()[idx].decode("utf-8"),
}
}
for _, value in batch_data.items():
assert value in list(expected_data.values())

def test_train_fn_two_files(self):
expected_data = _test_create_textjson_records(self.tempdir, use_two_files=True)
self.configs.batch_size_train = 6
self._dataset = Dataset(self.tempdir, self.configs)
ds = self._dataset.train_fn('train', False)
batch = next(iter(ds))
self.assertEqual(batch['text'].shape, [6])
self.assertEqual(batch['polarity'].shape, [6])
batch_data = {}
for idx, text_val in enumerate(batch["text"].numpy()):
batch_data[idx] = {
"text": text_val.decode("utf-8"),
"label": {
"polarity": batch["polarity"].numpy()[idx],
"question": batch["question"].numpy()[idx].decode("utf-8"),
}
}
for _, value in batch_data.items():
assert value in list(expected_data.values())
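
Condensed from the tests above, a minimal sketch of reading one batch (the record directory is assumed to already hold tfrecords written by TFRecordWriter):

# '/data/records' is illustrative.
configs = DatasetConfigs()
configs.batch_size_train = 3
dataset = Dataset('/data/records', configs)
batch = next(iter(dataset.train_fn('train', False)))
texts = [t.decode('utf-8') for t in batch['text'].numpy()]
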
47 changes: 42 additions & 5 deletions tests/test_generator_text.py
@@ -26,34 +26,71 @@ class TestTextJsonDatumGenerator(absltest.TestCase):

def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.data = {
self.data_1 = {
1: {
'text': 'this is text file',
'label': {
'polarity': 1
'polarity': 1,
'question': "meaning of this line?",
}
},
2: {
'text': 'this is json file',
'label': {
'polarity': 2
'polarity': 2,
'question': "meaning of this sentence?",
}
},
3: {
'text': 'this is label file',
'label': {
'polarity': 0
'polarity': 0,
'question': "meaning of this para?",
}
},
}
self.data_2 = {
4: {
'text': 'this is next text file',
'label': {
'polarity': 4,
'question': "meaning of next line?",
}
},
5: {
'text': 'this is next json file',
'label': {
'polarity': 5,
'question': "meaning of next sentence?",
}
},
6: {
'text': 'this is next label file',
'label': {
'polarity': 6,
'question': "meaning of next para?",
}
},
}
self.data = {**self.data_1, **self.data_2}
with open(os.path.join(self.tempdir, 'train.json'), 'w') as f:
json.dump(self.data, f)
json.dump(self.data_1, f)
self.gen_from_json = text.TextJsonDatumGenerator(self.tempdir)

def tearDown(self):
rmtree(self.tempdir)

def test_generate_datum(self):
for key, datum in self.gen_from_json(split='train'):
self.assertEqual(datum['text'], self.data_1[literal_eval(key)]['text'])
self.assertEqual(datum['polarity'], self.data_1[literal_eval(key)]['label']['polarity'])
self.assertEqual(datum['question'], self.data_1[literal_eval(key)]['label']['question'])

def test_generate_datum_multiple_files(self):
with open(os.path.join(self.tempdir, 'train_2.json'), 'w') as f:
json.dump(self.data_2, f)
gen_from_json = text.TextJsonDatumGenerator(self.tempdir)
for key, datum in gen_from_json(split='train'):
self.assertEqual(datum['text'], self.data[literal_eval(key)]['text'])
self.assertEqual(datum['polarity'], self.data[literal_eval(key)]['label']['polarity'])
self.assertEqual(datum['question'], self.data[literal_eval(key)]['label']['question'])
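
For reference, the call pattern these tests exercise, with an assumed data directory; any directory holding <split>*.json files in the fixture format works:

gen = text.TextJsonDatumGenerator('/data/glue_sst2')  # hypothetical path
for key, datum in gen(split='train'):
    # Keys are the JSON keys (strings); label values pass through
    # literal_eval where possible, so ints stay ints.
    print(key, datum['text'], datum['polarity'], datum['question'])
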
44 changes: 39 additions & 5 deletions tests/utils.py
@@ -78,34 +78,68 @@ def _test_create_seg_records(path):
writer.create_records()


def _test_create_textjson_records(path):
def _test_create_textjson_records(path, use_two_files=False):
tempdir = tempfile.mkdtemp()
data = {
1: {
'text': 'this is text file',
'label': {
'polarity': 1
'polarity': 1,
'question': "meaning of this line?",
}
},
2: {
'text': 'this is json file',
'label': {
'polarity': 2
'polarity': 2,
'question': "meaning of this sentence?",
}
},
3: {
'text': 'this is label file',
'label': {
'polarity': 0
'polarity': 0,
'question': "meaning of this para?",
}
},
}
num_examples = 3
final_data = data
if use_two_files:
data_2 = {
4: {
'text': 'this is next text file',
'label': {
'polarity': 4,
'question': "meaning of next line?",
}
},
5: {
'text': 'this is next json file',
'label': {
'polarity': 5,
'question': "meaning of next sentence?",
}
},
6: {
'text': 'this is next label file',
'label': {
'polarity': 6,
'question': "meaning of next para?",
}
},
}
with open(os.path.join(tempdir, 'train_2.json'), 'w') as f:
json.dump(data_2, f)
num_examples = 6
final_data = {**data, **data_2}
with open(os.path.join(tempdir, 'train.json'), 'w') as f:
json.dump(data, f)
gen_from_json = text.TextJsonDatumGenerator(tempdir)
serializer = DatumSerializer('text')
Path(path).mkdir(parents=True, exist_ok=True)
textjson_gen = text.TextJsonDatumGenerator(tempdir)
writer = TFRecordWriter(textjson_gen, serializer, path, 'train', 3)
writer = TFRecordWriter(textjson_gen, serializer, path, 'train', num_examples)
writer.create_records()
rmtree(tempdir)
return final_data
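
The helper above reduces to this end-to-end write path; a condensed sketch with illustrative paths:

serializer = DatumSerializer('text')
gen = text.TextJsonDatumGenerator('/data/json_inputs')  # dir of <split>*.json
writer = TFRecordWriter(gen, serializer, '/data/records', 'train', 6)
writer.create_records()  # cache + shuffle, then flush .tfrecord shards
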
