Add bulk operations utilities #224

Merged: 5 commits, Jun 10, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions app/src/db/__init__.py
@@ -0,0 +1 @@
__all__ = ["bulk_ops"]
143 changes: 143 additions & 0 deletions app/src/db/bulk_ops.py
@@ -0,0 +1,143 @@
"""Bulk database operations for performance.

Provides a bulk_upsert function.
"""
KevinJBoyer marked this conversation as resolved.
from typing import Any, Sequence

import psycopg
lorenyu marked this conversation as resolved.
from psycopg import rows, sql

Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row


def bulk_upsert(
cur: psycopg.Cursor,
table: str,
attributes: Sequence[str],
objects: Sequence[Any],
constraint: str,
update_condition: sql.SQL = sql.SQL(""),

[GitHub Actions / Lint warning on line 21 in app/src/db/bulk_ops.py (src/db/bulk_ops.py:21:33): B008 Do not perform function calls in argument defaults. The call is performed only once at function definition time. All calls to your function will reuse the result of that definition-time function call. If this is intended, assign the function call to a module-level variable and use that variable as a default value. (A hedged fix sketch follows this function.)]
):
"""Bulk insert or update a sequence of objects.

Insert a sequence of objects, or update on conflict.
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.

Args:
cur: the Cursor object from the pyscopg library
table: the name of the table to insert into or update
attributes: a sequence of attribute names to copy from each object
objects: a sequence of objects to upsert
constraint: the table unique constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
temp_table = f"temp_{table}"
create_temp_table(cur, temp_table=temp_table, src_table=table)
Contributor commented:
This is probably a very niche edge case, but what would happen if two temp tables were created with the same name by different processes? Does that cause any issues, or does them being in the transactions entirely shield them?

KevinJBoyer (Contributor Author) commented on Jun 6, 2024:

Great question, I tested it locally and it looks like the transaction isolation works like you'd expect. Here's the SQL I ran:

CREATE TEMP TABLE test (id INT) ON COMMIT DROP;
SELECT * FROM test;

-- In a separate connection!
BEGIN;
CREATE TEMP TABLE test (other INT) ON COMMIT DROP;
SELECT * FROM test;
COMMIT;

-- Back in the original connection
COMMIT;

bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
write_from_table_to_table(
cur,
src_table=temp_table,
dest_table=table,
columns=attributes,
constraint=constraint,
update_condition=update_condition,
)
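
A minimal sketch of the fix the linter suggests for the B008 warning above: hoist the definition-time call into a module-level constant (the _EMPTY_SQL name is hypothetical, not from this PR). sql.SQL("") is immutable, so reusing one instance across calls is safe:

_EMPTY_SQL = sql.SQL("")  # evaluated once at import time, now explicitly

def bulk_upsert(
    cur: psycopg.Cursor,
    table: str,
    attributes: Sequence[str],
    objects: Sequence[Any],
    constraint: str,
    update_condition: sql.SQL = _EMPTY_SQL,  # no function call in the default
):
    ...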


def create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str):
"""
Create table that lives only for the current transaction.
Use an existing table to determine the table structure.
Once the transaction is committed the temp table will be deleted.
Args:
temp_table: the name of the temporary table to create
src_table: the name of the existing table
"""
cur.execute(
sql.SQL(
"CREATE TEMP TABLE {temp_table}\
(LIKE {src_table})\
ON COMMIT DROP"
).format(
temp_table=sql.Identifier(temp_table),
src_table=sql.Identifier(src_table),
)
)


def bulk_insert(
cur: psycopg.Cursor,
table: str,
columns: Sequence[str],
objects: Sequence[Any],
):
"""
Write data from a sequence of objects to a temp table.
This function uses the PostgreSQL COPY command which is highly performant.
Args:
cur: the Cursor object from the psycopg library
table: the name of the temporary table
columns: a sequence of column names that are attributes of each object
objects: a sequence of objects with attributes defined by columns
"""
columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
query = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
table=sql.Identifier(table),
columns=columns_sql,
)
with cur.copy(query) as copy:
for obj in objects:
values = [getattr(obj, column) for column in columns]
copy.write_row(values)
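
For contrast, a sketch of the row-at-a-time path that COPY replaces (the temp_numbers table and its columns are hypothetical). psycopg's executemany batches the INSERTs, but COPY streams the rows and is generally much faster for large loads:

# Hypothetical non-COPY equivalent: the same rows via executemany
cur.executemany(
    "INSERT INTO temp_numbers (id, num) VALUES (%s, %s)",
    [(obj.id, obj.num) for obj in objects],
)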


def write_from_table_to_table(
cur: psycopg.Cursor,
src_table: str,
dest_table: str,
columns: Sequence[str],
constraint: str,
update_condition: sql.SQL = sql.SQL(""),

[GitHub Actions / Lint warning on line 104 in app/src/db/bulk_ops.py (src/db/bulk_ops.py:104:33): B008 Do not perform function calls in argument defaults (same warning as above; the same module-level-default fix applies).]
):
"""
Write data from one table to another.
If there are conflicts due to unique constraints, overwrite existing data.
Args:
cur: the Cursor object from the psycopg library
src_table: the name of the table that will be copied from
dest_table: the name of the table that will be written to
columns: a sequence of column names to copy over
constraint: the arbiter constraint to use to determine conflicts
update_condition: optional WHERE clause to limit updates for a
conflicting row
"""
columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
update_sql = sql.SQL(",").join(
[
sql.SQL("{column} = EXCLUDED.{column}").format(
column=sql.Identifier(column),
)
for column in columns
]
)
query = sql.SQL(
"INSERT INTO {dest_table}({columns})\
SELECT {columns} FROM {src_table}\
ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql}\
{update_condition}"
).format(
dest_table=sql.Identifier(dest_table),
columns=columns_sql,
src_table=sql.Identifier(src_table),
constraint=sql.Identifier(constraint),
update_sql=update_sql,
update_condition=update_condition,
)
cur.execute(query)


__all__ = ["bulk_upsert"]
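
To make the module docstring concrete, here is a hedged end-to-end usage sketch. The records table, its columns, the constraint name, the connection string, and the updated_at condition are all hypothetical, not part of this PR:

from dataclasses import dataclass

import psycopg
from psycopg import sql

from src.db import bulk_ops


@dataclass
class Record:
    id: str
    num: int


with psycopg.connect("postgresql://localhost/app") as conn:
    with conn.cursor() as cur:
        bulk_ops.bulk_upsert(
            cur,
            table="records",
            attributes=["id", "num"],
            objects=[Record(id="1", num=1), Record(id="2", num=2)],
            constraint="records_pkey",
            # Optionally overwrite a conflicting row only when some condition
            # holds, e.g. (hypothetical updated_at column):
            # update_condition=sql.SQL(
            #     "WHERE records.updated_at < EXCLUDED.updated_at"
            # ),
        )
    conn.commit()  # the temp staging table is dropped on commit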
89 changes: 89 additions & 0 deletions app/tests/src/db/test_bulk_ops.py
@@ -0,0 +1,89 @@
"""Tests for bulk_ops module"""
import operator
import random
from dataclasses import dataclass

from psycopg import rows, sql

import src.adapters.db as db
from src.db import bulk_ops


def test_bulk_upsert(db_session: db.Session):
conn = db_session.connection().connection
# Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the row_factory attribute, or that it functions as a context manager
with conn.cursor(row_factory=rows.class_row(Number)) as cur: # type: ignore
KevinJBoyer (Contributor Author) commented:

Would love to know if there's a better way of doing this. I also considered:

    db_client = db.PostgresDBClient()
    conn = db_client._engine.raw_connection()

but accessing _engine directly did not feel appropriate (and doesn't solve for the type issue in any case)

Contributor commented:

hmm, not sure

Contributor commented:

We could consider adding a raw_connection() method to the client class that does what you suggested. For the docs, mention that unless you're trying to do something very low level (i.e. in psycopg) you'll almost never actually want to use it.

KevinJBoyer (Contributor Author) commented:

I added this and included a comment with that context -- LMK what you think!
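
A hedged sketch of that raw_connection() helper (the surrounding PostgresDBClient code is an assumption here; SQLAlchemy's Engine.raw_connection() is the real call being wrapped):

class PostgresDBClient:
    ...  # existing client code

    def raw_connection(self):
        """Return the driver-level (psycopg) DBAPI connection.

        Unless you're doing something very low level (i.e. directly in
        psycopg, like COPY), you almost never want this; prefer sessions.
        """
        return self._engine.raw_connection()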

table = "temp_table"
attributes = ["id", "num"]
objects = [get_random_number_object() for i in range(100)]
constraint = "temp_table_pkey"

# Create a table for testing bulk upsert
cur.execute(
sql.SQL(
"CREATE TEMP TABLE {table}"
"("
"id TEXT NOT NULL,"
"num INT,"
"CONSTRAINT {constraint} PRIMARY KEY (id)"
")"
).format(
table=sql.Identifier(table),
constraint=sql.Identifier(constraint),
)
)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
objects,
constraint,
)
conn.commit()

# Check that all the objects were inserted
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
objects.sort(key=operator.attrgetter("id"))
assert records == objects

# Now modify half of the objects
for obj in objects[: int(len(objects) / 2)]:
obj.num = random.randint(1, 10000)

bulk_ops.bulk_upsert(
cur,
table,
attributes,
objects,
constraint,
)
conn.commit()
Contributor commented:

nit: it'd be nice to have the test case do a combination of inserts and updates rather than just inserts and updates separately

KevinJBoyer (Contributor Author) commented:

Added -- one round of inserts, then a second round of combo insert + updates
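
Roughly what that second round looks like (a sketch; the follow-up commit is not shown in this diff view):

# Combined round: mutate half the existing rows (updates) and append
# brand-new objects (inserts), then upsert everything in one call.
for obj in objects[: len(objects) // 2]:
    obj.num = random.randint(1, 10000)
objects.extend(get_random_number_object() for _ in range(50))
bulk_ops.bulk_upsert(cur, table, attributes, objects, constraint)
conn.commit()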


# Check that the objects were updated
cur.execute(
sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
table=sql.Identifier(table)
)
)
records = cur.fetchall()
objects.sort(key=operator.attrgetter("id"))
assert records == objects


@dataclass
class Number:
id: str
num: int


def get_random_number_object() -> Number:
return Number(
id=str(random.randint(1000000, 9999999)),
num=random.randint(1, 10000),
)
KevinJBoyer marked this conversation as resolved.