Skip to content

Commit

Permalink
Add FK following to determine relevant rows
Browse files Browse the repository at this point in the history
  • Loading branch information
Andre Senna committed May 24, 2023
1 parent 1fa828f commit f28b41a
Showing 1 changed file with 89 additions and 5 deletions.
94 changes: 89 additions & 5 deletions flybase2metta/sql_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
USE_PRECOMPUTED_NEAR_MATCHES = False
PRINT_PRECOMPUTED_NEAR_MATCHES = False
SKIP_SQL_JOIN = False
SKIP_FKEY_FOLLOWING = SKIP_SQL_JOIN or False

def _file_line_count(file_name):
output = subprocess.run(["wc", "-l", file_name], stdout=subprocess.PIPE)
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(self, sql_file_name, precomputed = None):
self.all_precomputed_nodes = set()
self.all_precomputed_node_names = set()
self.log_precomputed_nodes = None
self.relevant_fkeys = {}

Path(self.target_dir).mkdir(parents=True, exist_ok=True)
for filename in os.listdir(self.target_dir):
Expand Down Expand Up @@ -193,7 +195,7 @@ def _emit_precomputed_tables(self, output_file):
self.log_precomputed_nodes = True
table_count = 0
for table in self.precomputed.all_tables:
self._print_progress_bar(table_count, len(self.precomputed.all_tables), 50, 3, 4)
self._print_progress_bar(table_count, len(self.precomputed.all_tables), 50, 3, 5)
#print(table)
for row in table.rows:
#print(row)
Expand All @@ -219,7 +221,7 @@ def _emit_precomputed_tables(self, output_file):
node2 = self._add_node(AtomTypes.VERBATIM, value2)
self._add_execution(schema, node1, node2)
table_count += 1
self._print_progress_bar(table_count, len(self.precomputed.all_tables), 50, 3, 4)
self._print_progress_bar(table_count, len(self.precomputed.all_tables), 50, 3, 5)
self.log_precomputed_nodes = False

def _checkpoint(self, create_new, use_precomputed_filter=False):
Expand Down Expand Up @@ -381,6 +383,44 @@ def _new_row_precomputed(self, line):
if (not non_mapped_column(name)) and (name not in fkeys):
self.precomputed.check_field_value(self.current_table, name, value)

def _new_row_relevant_fkeys(self, line):
    """Emit atoms for a COPY data row whose primary key was flagged as a relevant fkey target.

    Skips the row entirely unless the current table has collected relevant
    foreign-key values and this row's primary key is among them. For a
    relevant row, emits an evaluation link per foreign-key column and an
    execution link per plain value column, mirroring _new_row().
    """
    if SCHEMA_ONLY or (self.current_table not in self.relevant_fkeys):
        return
    table = self.table_schema[self.current_table]
    table_short_name = short_name(self.current_table)
    pkey = table['primary_key']
    fkeys = table['foreign_keys']
    assert pkey,f"self.current_table = {self.current_table} pkey = {pkey} \n{table}"
    data = line.split("\t")
    if len(self.current_table_header) != len(data):
        self._error(f"Invalid row at line {self.line_count} Table: {self.current_table} Header: {self.current_table_header} Raw line: <{line}>")
        return
    # Locate the primary-key column; bail out early when this row's key was
    # never recorded as a referenced value (i.e. the row is not relevant).
    relevant_values = self.relevant_fkeys[self.current_table]
    pkey_node = None
    for column, field in zip(self.current_table_header, data):
        if column != pkey:
            continue
        if field not in relevant_values:
            return
        pkey_node = self._add_node(table_short_name, field)
        break
    # Fails if the declared primary key is absent from the COPY header.
    assert pkey_node is not None
    for column, field in zip(self.current_table_header, data):
        if non_mapped_column(column):
            continue
        if column in fkeys:
            # 'foreign_key' maps column name -> (referenced table, referenced field).
            referenced_table, referenced_field = table['foreign_key'][column]
            predicate_node = self._add_node(AtomTypes.PREDICATE, referenced_table)
            fkey_node = self._add_node(AtomTypes.CONCEPT, _compose_name(referenced_table, field))
            self._add_evaluation(predicate_node, pkey_node, fkey_node)
        elif column != pkey:
            # Plain data column: only emit when the field type is known and
            # the value maps to a node (e.g. non-null).
            ftype = self.current_field_types.get(column)
            if not ftype:
                continue
            value_node = self._add_value_node(table_short_name, ftype, field)
            if not value_node:
                continue
            schema_node = self._add_node(AtomTypes.SCHEMA, _compose_name(table_short_name, column))
            self._add_execution(schema_node, pkey_node, value_node)

def _new_row(self, line):
if SCHEMA_ONLY or (self.relevant_tables is not None and self.current_table not in self.relevant_tables):
return
Expand Down Expand Up @@ -418,6 +458,10 @@ def _new_row(self, line):
continue
if name in fkeys:
referenced_table, referenced_field = table['foreign_key'][name]
if not SKIP_FKEY_FOLLOWING:
if referenced_table not in self.relevant_fkeys:
self.relevant_fkeys[referenced_table] = set()
self.relevant_fkeys[referenced_table].add(value)
predicate_node = self._add_node(AtomTypes.PREDICATE, referenced_table)
fkey_node = self._add_node(AtomTypes.CONCEPT, _compose_name(referenced_table, value))
self._add_evaluation(predicate_node, pkey_node, fkey_node)
Expand Down Expand Up @@ -468,7 +512,7 @@ def _parse_step_1(self):
while line:
self.line_count += 1
if SHOW_PROGRESS:
self._print_progress_bar(self.line_count, file_size, 50, 1, 4 if self.precomputed else 2)
self._print_progress_bar(self.line_count, file_size, 50, 1, 5)
line = line.replace('\n', '').strip()
if state == State.WAIT_KNOWN_COMMAND:
if line.startswith(CREATE_TABLE_PREFIX):
Expand Down Expand Up @@ -508,7 +552,7 @@ def _parse_step_2(self):
while line:
self.line_count += 1
if SHOW_PROGRESS:
self._print_progress_bar(self.line_count, file_size, 50, 2, 4)
self._print_progress_bar(self.line_count, file_size, 50, 2, 5)
if not self.precomputed.all_tables_mapped():
line = line.replace('\n', '').strip()
if state == State.WAIT_KNOWN_COMMAND:
Expand Down Expand Up @@ -552,7 +596,7 @@ def _parse_step_3(self):
if self.expression_chunk_count >= EXPRESSIONS_PER_CHUNK:
self._checkpoint(True)
if SHOW_PROGRESS:
self._print_progress_bar(self.line_count, file_size, 50, 4 if self.precomputed else 2, 4 if self.precomputed else 2)
self._print_progress_bar(self.line_count, file_size, 50, 4, 5)
line = line.replace('\n', '').strip()
if state == State.WAIT_KNOWN_COMMAND:
if line.startswith(COPY_PREFIX):
Expand All @@ -570,6 +614,44 @@ def _parse_step_3(self):
line = file.readline()
self._checkpoint(False)

def _parse_step_4(self):
    """Re-scan the SQL dump and emit atoms for rows referenced through foreign keys.

    Final data pass (progress-bar step 5/5): walks every COPY block again and
    hands each data row to _new_row_relevant_fkeys(), which only emits atoms
    for rows whose primary key was collected into self.relevant_fkeys during
    the earlier pass. Output is checkpointed whenever the current expression
    chunk reaches EXPRESSIONS_PER_CHUNK.
    """
    self.line_count = 0
    file_size = FILE_SIZE  # assumed total line count — used only to scale the progress bar

    if not self.precomputed:
        # NOTE(review): this discard pass appears to duplicate work done in
        # _parse_step_1(); re-running it here would append the same tables to
        # self.discarded_tables again and re-log the errors — confirm and dedupe.
        for key, table in self.table_schema.items():
            if not table['primary_key']:
                self.discarded_tables.append(key)
                self._error(f"Discarded table {key}. No PRIMARY KEY defined.")

    state = State.WAIT_KNOWN_COMMAND
    with open(self.sql_file_name, 'r') as file:
        line = file.readline()
        while line:
            self.line_count += 1
            if self.expression_chunk_count >= EXPRESSIONS_PER_CHUNK:
                self._checkpoint(True)  # flush the full chunk and start a new one
            if SHOW_PROGRESS:
                self._print_progress_bar(self.line_count, file_size, 50, 5, 5)
            line = line.replace('\n', '').strip()
            if state == State.WAIT_KNOWN_COMMAND:
                if line.startswith(COPY_PREFIX):
                    if self._start_copy(line):
                        state = State.READING_COPY
            elif state == State.READING_COPY:
                if line.startswith(COPY_SUFFIX):
                    # End-of-data marker: leave the COPY block.
                    state = State.WAIT_KNOWN_COMMAND
                else:
                    if not SKIP_SQL_JOIN:
                        self._new_row_relevant_fkeys(line)
            else:
                print(f"Invalid state {state}")
                assert False  # unreachable: the state machine has exactly two states
            line = file.readline()
    self._checkpoint(False)

def parse(self):
self._setup()
self._parse_step_1()
Expand All @@ -585,6 +667,8 @@ def parse(self):
f.write("\n\n")
f.close()
self._parse_step_3()
if not SKIP_FKEY_FOLLOWING:
self._parse_step_4()
if self.errors:
print(f"Errors occured while processing this SQL file. See them in {self.error_file_name}")
self._tear_down()
Expand Down

0 comments on commit f28b41a

Please sign in to comment.