Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
Reuse connection
Browse files Browse the repository at this point in the history
Tap-postgres opens a new connection every time it needs to cast a value.
This is highly inefficient as opening a connection is usually a slow and
resource-intensive operation. An easy fix would be to use something like
PgBouncer, but it's even better if we open just one connection and
reuse it for all queries.

To fix the issue we do two things:
1. Create a Singleton Postgres connection wrapper. This wrapper
   actually holds up to two connections, since we need two different
   connection factories. The `connect` method returns the connection we
   need based on the arguments provided.
2. Remove `with` statements when asking for a connection. `with`
   statements are great whenever we need to ensure a resource is
   properly closed after it has been used, but in our specific case, we
   don't want to close connections after each query.
  • Loading branch information
ivanovyordan committed Aug 30, 2021
1 parent 11914d4 commit c55e952
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 359 deletions.
108 changes: 54 additions & 54 deletions tap_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def do_discovery(conn_config):
Returns: list of discovered streams
"""
with post_db.open_connection(conn_config) as conn:
LOGGER.info("Discovering db %s", conn_config['dbname'])
streams = discover_db(conn, conn_config.get('filter_schemas'))
conn = post_db.open_connection()
LOGGER.info("Discovering db %s", conn_config['dbname'])
streams = discover_db(conn, conn_config.get('filter_schemas'))

if len(streams) == 0:
raise RuntimeError('0 tables were discovered across the entire cluster')
Expand All @@ -50,21 +50,21 @@ def do_discovery(conn_config):
return streams


def do_sync_full_table(conn_config, stream, state, desired_columns, md_map):
def do_sync_full_table(stream, state, desired_columns, md_map):
"""
Runs full table sync
"""
LOGGER.info("Stream %s is using full_table replication", stream['tap_stream_id'])
sync_common.send_schema_message(stream, [])
if md_map.get((), {}).get('is-view'):
state = full_table.sync_view(conn_config, stream, state, desired_columns, md_map)
state = full_table.sync_view(stream, state, desired_columns, md_map)
else:
state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map)
state = full_table.sync_table(stream, state, desired_columns, md_map)
return state


# Possible state keys: replication_key, replication_key_value, version
def do_sync_incremental(conn_config, stream, state, desired_columns, md_map):
def do_sync_incremental(stream, state, desired_columns, md_map):
"""
Runs Incremental sync
"""
Expand All @@ -82,7 +82,7 @@ def do_sync_incremental(conn_config, stream, state, desired_columns, md_map):
state = singer.write_bookmark(state, stream['tap_stream_id'], 'replication_key', replication_key)

sync_common.send_schema_message(stream, [replication_key])
state = incremental.sync_table(conn_config, stream, state, desired_columns, md_map)
state = incremental.sync_table(stream, state, desired_columns, md_map)

return state

Expand Down Expand Up @@ -164,27 +164,27 @@ def sync_traditional_stream(conn_config, stream, state, sync_method, end_lsn):
LOGGER.warning('There are no columns selected for stream %s, skipping it', stream['tap_stream_id'])
return state

register_type_adapters(conn_config)
register_type_adapters()

if sync_method == 'full':
state = singer.set_currently_syncing(state, stream['tap_stream_id'])
state = do_sync_full_table(conn_config, stream, state, desired_columns, md_map)
state = do_sync_full_table(stream, state, desired_columns, md_map)
elif sync_method == 'incremental':
state = singer.set_currently_syncing(state, stream['tap_stream_id'])
state = do_sync_incremental(conn_config, stream, state, desired_columns, md_map)
state = do_sync_incremental(stream, state, desired_columns, md_map)
elif sync_method == 'logical_initial':
state = singer.set_currently_syncing(state, stream['tap_stream_id'])
LOGGER.info("Performing initial full table sync")
state = singer.write_bookmark(state, stream['tap_stream_id'], 'lsn', end_lsn)

sync_common.send_schema_message(stream, [])
state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map)
state = full_table.sync_table(stream, state, desired_columns, md_map)
state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', None)
elif sync_method == 'logical_initial_interrupted':
state = singer.set_currently_syncing(state, stream['tap_stream_id'])
LOGGER.info("Initial stage of full table sync was interrupted. resuming...")
sync_common.send_schema_message(stream, [])
state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map)
state = full_table.sync_table(stream, state, desired_columns, md_map)
else:
raise Exception("unknown sync method {} for stream {}".format(sync_method, stream['tap_stream_id']))

Expand Down Expand Up @@ -222,53 +222,53 @@ def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_fil
return state


def register_type_adapters(conn_config):
def register_type_adapters():
"""
//todo doc needed
"""
with post_db.open_connection(conn_config) as conn:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# citext[]
cur.execute("SELECT typarray FROM pg_type where typname = 'citext'")
citext_array_oid = cur.fetchone()
if citext_array_oid:
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(citext_array_oid[0],), 'CITEXT[]', psycopg2.STRING))

# bit[]
cur.execute("SELECT typarray FROM pg_type where typname = 'bit'")
bit_array_oid = cur.fetchone()[0]
conn = post_db.open_connection()
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# citext[]
cur.execute("SELECT typarray FROM pg_type where typname = 'citext'")
citext_array_oid = cur.fetchone()
if citext_array_oid:
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(bit_array_oid,), 'BIT[]', psycopg2.STRING))

# UUID[]
cur.execute("SELECT typarray FROM pg_type where typname = 'uuid'")
uuid_array_oid = cur.fetchone()[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(uuid_array_oid,), 'UUID[]', psycopg2.STRING))

# money[]
cur.execute("SELECT typarray FROM pg_type where typname = 'money'")
money_array_oid = cur.fetchone()[0]
(citext_array_oid[0],), 'CITEXT[]', psycopg2.STRING))

# bit[]
cur.execute("SELECT typarray FROM pg_type where typname = 'bit'")
bit_array_oid = cur.fetchone()[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(bit_array_oid,), 'BIT[]', psycopg2.STRING))

# UUID[]
cur.execute("SELECT typarray FROM pg_type where typname = 'uuid'")
uuid_array_oid = cur.fetchone()[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(uuid_array_oid,), 'UUID[]', psycopg2.STRING))

# money[]
cur.execute("SELECT typarray FROM pg_type where typname = 'money'")
money_array_oid = cur.fetchone()[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(money_array_oid,), 'MONEY[]', psycopg2.STRING))

# json and jsonb
# pylint: disable=unnecessary-lambda
psycopg2.extras.register_default_json(loads=lambda x: str(x))
psycopg2.extras.register_default_jsonb(loads=lambda x: str(x))

# enum[]'s
cur.execute("SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid")
for oid in cur.fetchall():
enum_oid = oid[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(money_array_oid,), 'MONEY[]', psycopg2.STRING))

# json and jsonb
# pylint: disable=unnecessary-lambda
psycopg2.extras.register_default_json(loads=lambda x: str(x))
psycopg2.extras.register_default_jsonb(loads=lambda x: str(x))

# enum[]'s
cur.execute("SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid")
for oid in cur.fetchall():
enum_oid = oid[0]
psycopg2.extensions.register_type(
psycopg2.extensions.new_array_type(
(enum_oid,), 'ENUM_{}[]'.format(enum_oid), psycopg2.STRING))
(enum_oid,), 'ENUM_{}[]'.format(enum_oid), psycopg2.STRING))


def do_sync(conn_config, catalog, default_replication_method, state, state_file=None):
Expand All @@ -281,7 +281,7 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file=
LOGGER.info("Selected streams: %s ", [s['tap_stream_id'] for s in streams])
if any_logical_streams(streams, default_replication_method):
# Use of logical replication requires fetching an lsn
end_lsn = logical_replication.fetch_current_lsn(conn_config)
end_lsn = logical_replication.fetch_current_lsn()
LOGGER.debug("end_lsn = %s ", end_lsn)
else:
end_lsn = None
Expand Down
39 changes: 12 additions & 27 deletions tap_postgres/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import List
from dateutil.parser import parse

from tap_postgres.postgres import Postgres

LOGGER = singer.get_logger('tap_postgres')

CURSOR_ITER_SIZE = 20000
Expand Down Expand Up @@ -38,26 +40,9 @@ def fully_qualified_table_name(schema, table):
return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table))


def open_connection(conn_config, logical_replication=False):
cfg = {
'application_name': 'pipelinewise',
'host': conn_config['host'],
'dbname': conn_config['dbname'],
'user': conn_config['user'],
'password': conn_config['password'],
'port': conn_config['port'],
'connect_timeout': 30
}

if conn_config.get('sslmode'):
cfg['sslmode'] = conn_config['sslmode']

if logical_replication:
cfg['connection_factory'] = psycopg2.extras.LogicalReplicationConnection
def open_connection(logical_replication=False):
return Postgres.get_instance().connect(logical_replication)

conn = psycopg2.connect(**cfg)

return conn

def prepare_columns_for_select_sql(c, md_map):
column_name = ' "{}" '.format(canonicalize_identifier(c))
Expand Down Expand Up @@ -191,14 +176,14 @@ def selected_row_to_singer_message(stream, row, version, columns, time_extracted
time_extracted=time_extracted)


def hstore_available(conn_info):
with open_connection(conn_info) as conn:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur:
cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """)
res = cur.fetchone()
if res and res[0]:
return True
return False
def hstore_available():
    """Report whether the ``hstore`` extension is installed in the database.

    Runs a lookup against ``pg_available_extensions`` on the connection
    returned by ``open_connection()``.

    Returns:
        bool: True when a row for ``hstore`` exists and its
        ``installed_version`` is non-empty, False otherwise.
    """
    # psycopg2's connection context manager commits on success and rolls back
    # on error WITHOUT closing the connection, so the connection is still
    # reused while the transaction opened by this query is not left pending
    # (or aborted) for whoever uses the connection next.
    # NOTE(review): a commit also ends any server-side cursor another caller
    # may still have open on this same connection - confirm no caller invokes
    # this function while iterating such a cursor.
    with open_connection() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur:
            cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """)
            res = cur.fetchone()
            return bool(res and res[0])


def compute_tap_stream_id(schema_name, table_name):
Expand Down
56 changes: 56 additions & 0 deletions tap_postgres/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from singer import utils

import psycopg2
import psycopg2.extras


# pylint: disable=missing-class-docstring,missing-function-docstring
class Postgres:
    """Singleton wrapper that caches the tap's PostgreSQL connections.

    At most two connections are kept, keyed by kind: ``"transactional"`` for
    regular queries and ``"logical"`` for connections created with the
    ``LogicalReplicationConnection`` factory. Use ``Postgres.get_instance()``
    to obtain the wrapper; calling the constructor directly once an instance
    exists raises ``RuntimeError``.
    """

    # The one shared instance; assigned by ``__init__`` on first construction.
    __instance = None

    @staticmethod
    def get_instance():
        """Return the shared ``Postgres`` instance, creating it on first use."""
        if Postgres.__instance is None:
            # ``__init__`` stores the new object in ``Postgres.__instance``.
            Postgres()

        return Postgres.__instance

    @staticmethod
    def get_configuration(logical_replication):
        """Build the keyword arguments passed to ``psycopg2.connect``.

        The connection settings are read from the ``config`` attribute of the
        result of singer's ``utils.parse_args``.

        Args:
            logical_replication: when truthy, request the
                ``LogicalReplicationConnection`` connection factory.

        Returns:
            dict: keyword arguments for ``psycopg2.connect``.

        Raises:
            KeyError: if ``host``, ``dbname``, ``user``, ``password`` or
                ``port`` is missing from the parsed config.
        """
        args = utils.parse_args({})
        conn_config = args.config

        cfg = {
            'application_name': 'pipelinewise',
            'host': conn_config['host'],
            'dbname': conn_config['dbname'],
            'user': conn_config['user'],
            'password': conn_config['password'],
            'port': conn_config['port'],
            'connect_timeout': 30,
        }

        # ``sslmode`` is optional; only forward it when it is set.
        if conn_config.get('sslmode'):
            cfg['sslmode'] = conn_config['sslmode']

        if logical_replication:
            cfg['connection_factory'] = psycopg2.extras.LogicalReplicationConnection

        return cfg

    def __init__(self):
        """Register this object as the singleton instance.

        Raises:
            RuntimeError: if an instance has already been created.
        """
        if Postgres.__instance is not None:
            raise RuntimeError("This class is a singleton!")

        Postgres.__instance = self
        # One cached connection per kind; ``None`` until first requested.
        self.connections = {"logical": None, "transactional": None}

    def connect(self, logical_replication=False):
        """Return the cached connection of the requested kind.

        A new connection is opened only when none is cached yet or the cached
        one reports itself as closed.

        Args:
            logical_replication: select the ``"logical"`` connection when
                truthy, otherwise the ``"transactional"`` one.

        Returns:
            The cached (or newly opened) psycopg2 connection.
        """
        connection_type = "logical" if logical_replication else "transactional"
        connection = self.connections[connection_type]

        if connection is None or connection.closed:
            config = Postgres.get_configuration(logical_replication)
            connection = psycopg2.connect(**config)
            self.connections[connection_type] = connection

        return connection
54 changes: 27 additions & 27 deletions tap_postgres/stream_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,33 @@ def refresh_streams_schema(conn_config: Dict, streams: List[Dict]):
LOGGER.debug('Current streams schemas %s', streams)

# Run discovery to get the streams most up to date json schemas
with open_connection(conn_config) as conn:
new_discovery = {
stream['tap_stream_id']: stream
for stream in discover_db(conn, conn_config.get('filter_schemas'), [st['table_name'] for st in streams])
}

LOGGER.debug('New discovery schemas %s', new_discovery)

# For every stream dictionary, update the schema and metadata from the new discovery
for idx, stream in enumerate(streams):
# update schema
streams[idx]['schema'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['schema'])

# Update metadata
#
# 1st step: new discovery doesn't contain non-discoverable metadata: e.g replication method & key, selected
# so let's copy those from the original stream object
md_map = metadata.to_map(stream['metadata'])
meta = md_map.get(())

for idx_met, metadatum in enumerate(new_discovery[stream['tap_stream_id']]['metadata']):
if not metadatum['breadcrumb']:
meta.update(new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'])
new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'] = meta

# 2nd step: now copy all the metadata from the updated new discovery to the original stream
streams[idx]['metadata'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['metadata'])
conn = open_connection()
new_discovery = {
stream['tap_stream_id']: stream
for stream in discover_db(conn, conn_config.get('filter_schemas'), [st['table_name'] for st in streams])
}

LOGGER.debug('New discovery schemas %s', new_discovery)

# For every stream dictionary, update the schema and metadata from the new discovery
for idx, stream in enumerate(streams):
# update schema
streams[idx]['schema'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['schema'])

# Update metadata
#
# 1st step: new discovery doesn't contain non-discoverable metadata: e.g replication method & key, selected
# so let's copy those from the original stream object
md_map = metadata.to_map(stream['metadata'])
meta = md_map.get(())

for idx_met, metadatum in enumerate(new_discovery[stream['tap_stream_id']]['metadata']):
if not metadatum['breadcrumb']:
meta.update(new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'])
new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'] = meta

# 2nd step: now copy all the metadata from the updated new discovery to the original stream
streams[idx]['metadata'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['metadata'])

LOGGER.debug('Updated streams schemas %s', streams)

Expand Down
Loading

0 comments on commit c55e952

Please sign in to comment.