# Adapted and modified from https://github.com/sheffieldnlp/fever-baselines/tree/master/src/scripts
# which is adapted from https://github.com/facebookresearch/DrQA/blob/master/scripts/retriever/build_db.py
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
"""A script to read in and store documents in a sqlite database."""
import argparse
import json
import logging
import os
import sqlite3
from multiprocessing import Pool as ProcessPool

from tqdm import tqdm

import utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Build db")
# TODO add time for logging

# ------------------------------------------------------------------------------
# Store corpus.
# ------------------------------------------------------------------------------
def get_contents(filename):
    """Parse the contents of a file. Each line is a JSON encoded document."""
    documents = []
    with open(filename) as f:
        for line in f:
            # Parse document
            doc = json.loads(line)
            # Skip if it is empty or None
            if not doc:
                continue
            # Add the document
            documents.append((utils.normalize(doc['id']), doc['text']))
    return documents
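

# Input files are expected to be JSON-lines: one JSON object per line, with
# `id` and `text` fields. A hypothetical example line (values illustrative only):
#   {"id": "Some_Wikipedia_Page", "text": "Some Wikipedia Page is ..."}
# utils.normalize presumably applies unicode normalization to the id so that
# later lookups use a consistent form.
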
def store_contents(data_path, save_path, num_workers=4, num_files=5):
    """Preprocess and store a corpus of documents in sqlite.

    Args:
        data_path: Root path to directory (or directory of directories) of files
            containing json encoded documents (must have `id` and `text` fields).
        save_path: Path to output sqlite db.
        num_workers: Number of parallel processes to use when reading docs.
        num_files: Split the db into num_files files.
    """
    logger.info('Reading into database...')
    files = [f for f in utils.iter_files(data_path)]
    # Split the file list into num_files chunks of roughly equal size, one
    # chunk per output db. Ceiling division via slicing guarantees every file
    # lands in exactly one chunk, even when len(files) is not a multiple of
    # num_files.
    chunk_size = -(-len(files) // num_files)
    filelist = [files[i * chunk_size:(i + 1) * chunk_size]
                for i in range(num_files)]
    for i, files in enumerate(filelist):
        logger.info('Building %i-th db...' % i)
        temp_save_path = os.path.join(save_path, 'fever%i.db' % i)
        if os.path.isfile(temp_save_path):
            raise RuntimeError('%s already exists! Not overwriting.' % temp_save_path)
        conn = sqlite3.connect(temp_save_path)
        c = conn.cursor()
        c.execute("CREATE TABLE documents (id PRIMARY KEY, text);")
        workers = ProcessPool(num_workers)
        count = 0
        with tqdm(total=len(files)) as pbar:
            # imap_unordered yields each file's (id, text) pairs as workers finish.
            for pairs in workers.imap_unordered(get_contents, files):
                count += len(pairs)
                c.executemany("INSERT INTO documents VALUES (?,?)", pairs)
                pbar.update()
        workers.close()
        workers.join()
        logger.info('Read %d docs.' % count)
        logger.info('Committing...')
        conn.commit()
        conn.close()
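
# Note on output layout: each chunk gets its own pool, connection, and file;
# the resulting dbs are named fever0.db ... fever{num_files-1}.db under
# save_path, and each is committed once after its chunk finishes.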
# ------------------------------------------------------------------------------
# Main.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_path', type=str, help='path/to/data')
    parser.add_argument('save_path', type=str, help='path/to/saved')
    # Defaults match store_contents; passing None through would crash the
    # chunk-size division above.
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of CPU processes (for tokenizing, etc)')
    parser.add_argument('--num-files', type=int, default=5,
                        help='Number of db files')
    args = parser.parse_args()

    save_dir = args.save_path
    if not os.path.exists(save_dir):
        logger.info("Save directory doesn't exist. Making {0}".format(save_dir))
        os.makedirs(save_dir)

    store_contents(
        args.data_path, args.save_path, args.num_workers, args.num_files
    )
# python build_db.py data/wiki-pages data/fever
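
# A minimal sketch of reading a document back from one of the resulting db
# files (hypothetical id; assumes the example paths above):
#   import sqlite3
#   conn = sqlite3.connect('data/fever/fever0.db')
#   c = conn.cursor()
#   c.execute("SELECT text FROM documents WHERE id = ?", ('Some_Wikipedia_Page',))
#   row = c.fetchone()  # None if the id is not stored in this shard
#   conn.close()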