Skip to content

Commit

Permalink
refactor: export_areas_as_sql
Browse files Browse the repository at this point in the history
  • Loading branch information
joaquimds committed Dec 17, 2024
1 parent 25a2fb0 commit 1ed9ba4
Showing 1 changed file with 81 additions and 45 deletions.
126 changes: 81 additions & 45 deletions hub/management/commands/export_areas_as_sql.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
import json
from pathlib import Path

Check failure on line 1 in hub/management/commands/export_areas_as_sql.py

View workflow job for this annotation

GitHub Actions / lint-python

Imports are incorrectly sorted and/or formatted.
from typing import Any
from django.conf import settings
from dataclasses import dataclass, field
from psycopg import Column

# from django postgis
from django.core.management.base import BaseCommand
from pathlib import Path

from django.conf import settings
from django.db import connection

QUERIES = {
"hub_areatype": "SELECT * FROM hub_areatype",
# Get the hub_areatype.code column for matching areas to area types in the output SQL, as IDs will have changed
"hub_area": "SELECT *,(SELECT code FROM hub_areatype WHERE hub_areatype.id = area_type_id) as area_type_code FROM hub_area",
}

@dataclass
class TableConfig:
table_name: str
extra_select_columns: dict[str, str] = field(default_factory=dict)
output_column_templates: dict[str, str] = field(default_factory=dict)


TABLES = [
TableConfig(table_name="hub_areatype"),
TableConfig(
table_name="hub_area",
extra_select_columns={
"area_type_code": "(SELECT code FROM hub_areatype WHERE hub_areatype.id = hub_area.area_type_id)"
},
output_column_templates={
"area_type_id": "(SELECT id FROM hub_areatype WHERE code = '{area_type_code}')"
},
),
]


class Command(BaseCommand):
Expand All @@ -20,46 +36,66 @@ class Command(BaseCommand):
without causing primary key conflicts.
"""

def escape_sql_string(self, value):
if value is None:
return "NULL"
# Replace single quote with two single quotes
return str(value).replace("'", "''")

def handle(self, *args, **options):
print("Exporting areas and area types from current database to data/areas.psql")
count = 0
output_file: Path = settings.BASE_DIR / "data" / "areas.psql"
with output_file.open("w", encoding="utf8") as f:
for table_name, table_query in QUERIES.items():
with connection.cursor() as cursor:
cursor.execute(f"{table_query} ORDER BY id ASC")
rows = cursor.fetchall()
# The [1:] indices on cursor.description and row are to drop the ID column, which is first
columns = [description for description in cursor.description[1:]]
column_names = ",".join(
[
description.name
for description in cursor.description[1:]
if description.name != "area_type_code" # Skip outputting derived column
]
for table_config in TABLES:
rows, columns = self.do_query(table_config)
for row in rows:
output_record = self.get_output_record(row, columns, table_config)
column_names = ",".join(output_record.keys())
output_values = ",".join(output_record.values())
f.write(
f"INSERT INTO {table_config.table_name} ({column_names}) VALUES ({output_values});\n"
)

for row in rows:
output_row = []
for i, value in enumerate(row[1:]):
column = columns[i]
if column.name == "area_type_code":
continue # Skip outputting derived column
if column.name == "area_type_id":
area_type_code = row[-1]
value = f"(SELECT id FROM hub_areatype WHERE code = '{area_type_code}')"
else:
# output the value as a string, cast to the correct type
value = f"'{self.escape_sql_string(value)}'::{column.type_display}"
output_row.append(value)
f.write(
f"INSERT INTO {table_name} ({column_names}) VALUES ({','.join(output_row)});\n"
)
count += 1
count += 1
print(f"Exported {count} rows to data/areas.psql")

def do_query(
self, table_config: TableConfig
) -> tuple[list[tuple[Any]], tuple[Column]]:
with connection.cursor() as cursor:
select = "*"
for column_name, column_select in table_config.extra_select_columns.items():
select += f", {column_select} as {column_name}"
cursor.execute(
f"SELECT {select} FROM {table_config.table_name} ORDER BY id ASC"
)
rows = cursor.fetchall()
columns = cursor.description
return (rows, columns)

def escape_sql_string(self, value: Any) -> str:
if value is None:
return "NULL"
# Replace single quote with two single quotes
return str(value).replace("'", "''")

def get_output_record(
self, row: tuple[Any], columns: tuple[Column], table_config: TableConfig
) -> dict[str, str]:
record = {}
for i, column in enumerate(columns):
# Don't output ID columns or extra select columns (these can't be imported)
if column.name == "id" or column.name in table_config.extra_select_columns:
continue

if column.name in table_config.output_column_templates:
template = table_config.output_column_templates[column.name]
value = self.template_output_value(template, row, columns)
else:
# output the value as a string, cast to the correct type in postgres
value = f"'{self.escape_sql_string(row[i])}'::{column.type_display}"

record[column.name] = value
return record

def template_output_value(
self, template: str, row: tuple[Any], columns: tuple[Column]
) -> str:
value = template
for i, column in enumerate(columns):
value = value.replace("{" + column.name + "}", str(row[i]))
return value

0 comments on commit 1ed9ba4

Please sign in to comment.