diff --git a/tap_mysql/connection.py b/tap_mysql/connection.py index b49b762..616fe68 100644 --- a/tap_mysql/connection.py +++ b/tap_mysql/connection.py @@ -11,11 +11,14 @@ LOGGER = singer.get_logger('tap_mysql') CONNECT_TIMEOUT_SECONDS = 30 -READ_TIMEOUT_SECONDS = 3600 # We need to hold onto this for self-signed SSL MATCH_HOSTNAME = ssl.match_hostname +DEFAULT_SESSION_SQLS = ['SET @@session.time_zone="+0:00"', + 'SET @@session.wait_timeout=28800', + 'SET @@session.net_read_timeout=3600', + 'SET @@session.innodb_lock_wait_timeout=3600'] @backoff.on_exception(backoff.expo, (pymysql.err.OperationalError), @@ -23,38 +26,31 @@ factor=2) def connect_with_backoff(connection): connection.connect() + run_session_sqls(connection) + + return connection + + +def run_session_sqls(connection): + session_sqls = connection.session_sqls warnings = [] - with connection.cursor() as cur: - try: - cur.execute('SET @@session.time_zone="+0:00"') - except pymysql.err.InternalError as exc: - warnings.append(f'Could not set session.time_zone. Error: ({exc.args[0]}) {exc.args[1]}') - - try: - cur.execute('SET @@session.wait_timeout=28800') - except pymysql.err.InternalError as exc: - warnings.append(f'Could not set session.wait_timeout. Error: ({exc.args[0]}) {exc.args[1]}') - - try: - cur.execute(f"SET @@session.net_read_timeout={READ_TIMEOUT_SECONDS}") - except pymysql.err.InternalError as exc: - warnings.append(f'Could not set session.net_read_timeout. Error: ({exc.args[0]}) {exc.args[1]}') - - try: - cur.execute('SET @@session.innodb_lock_wait_timeout=3600') - except pymysql.err.InternalError as exc: - warnings.append( - f'Could not set session.innodb_lock_wait_timeout. Error: ({exc.args[0]}) {exc.args[1]}' - ) - - if warnings: - LOGGER.info(("Encountered non-fatal errors when configuring MySQL session that could " - "impact performance:")) - for warning in warnings: - LOGGER.warning(warning) + if session_sqls and isinstance(session_sqls, list): + for sql in session_sqls: + try: + run_sql(connection, sql) + except pymysql.err.InternalError: + warnings.append(f'Could not set session variable: {sql}') + + if warnings: + LOGGER.warning('Encountered non-fatal errors when configuring session that could impact performance:') + for warning in warnings: + LOGGER.warning(warning) - return connection + +def run_sql(connection, sql): + with connection.cursor() as cur: + cur.execute(sql) def parse_internal_hostname(hostname): @@ -91,7 +87,6 @@ def __init__(self, config): "port": int(config["port"]), "cursorclass": config.get("cursorclass") or pymysql.cursors.SSCursor, "connect_timeout": CONNECT_TIMEOUT_SECONDS, - "read_timeout": READ_TIMEOUT_SECONDS, "charset": "utf8", } @@ -142,6 +137,8 @@ def __init__(self, config): self.ctx.verify_mode = ssl.CERT_NONE self.client_flag |= CLIENT.SSL + self.session_sqls = config.get("session_sqls", DEFAULT_SESSION_SQLS) + def __enter__(self): return self diff --git a/tests/test_tap_mysql.py b/tests/test_tap_mysql.py index 598e7b3..1ea230a 100644 --- a/tests/test_tap_mysql.py +++ b/tests/test_tap_mysql.py @@ -1,10 +1,12 @@ import unittest +from unittest.mock import patch +import pymysql import singer import singer.metadata import tap_mysql -from tap_mysql.connection import connect_with_backoff +from tap_mysql.connection import connect_with_backoff, MySQLConnection try: import tests.utils as test_utils @@ -982,6 +984,78 @@ def tearDown(self) -> None: cursor.execute('DROP TABLE good_pk_tab;') +class MySQLConnectionMock(MySQLConnection): + """ + Mocked MySQLConnection class + """ + def __init__(self, config): + super().__init__(config) + + self.executed_queries = [] + + def run_sql(self, sql): + self.executed_queries.append(sql) + + +class TestSessionSqls(unittest.TestCase): + + def setUp(self) -> None: + self.executed_queries = [] + + def run_sql_mock(self, connection, sql): + if sql.startswith('INVALID-SQL'): + raise pymysql.err.InternalError + + self.executed_queries.append(sql) + + def test_open_connections_with_default_session_sqls(self): + """Default session parameters should be applied if no custom session SQLs""" + with patch('tap_mysql.connection.MySQLConnection.connect'): + with patch('tap_mysql.connection.run_sql') as run_sql_mock: + run_sql_mock.side_effect = self.run_sql_mock + conn = MySQLConnectionMock(config=test_utils.get_db_config()) + connect_with_backoff(conn) + + # Test if session variables applied on connection + self.assertEqual(self.executed_queries, tap_mysql.connection.DEFAULT_SESSION_SQLS) + + def test_open_connections_with_session_sqls(self): + """Custom session parameters should be applied if defined""" + session_sqls = [ + 'SET SESSION max_statement_time=0', + 'SET SESSION wait_timeout=28800' + ] + + with patch('tap_mysql.connection.MySQLConnection.connect'): + with patch('tap_mysql.connection.run_sql') as run_sql_mock: + run_sql_mock.side_effect = self.run_sql_mock + conn = MySQLConnectionMock(config={**test_utils.get_db_config(), + **{'session_sqls': session_sqls}}) + connect_with_backoff(conn) + + # Test if session variables applied on connection + self.assertEqual(self.executed_queries, session_sqls) + + def test_open_connections_with_invalid_session_sqls(self): + """Invalid SQLs in session_sqls should be ignored""" + session_sqls = [ + 'SET SESSION max_statement_time=0', + 'INVALID-SQL-SHOULD-BE-SILENTLY-IGNORED', + 'SET SESSION wait_timeout=28800' + ] + + with patch('tap_mysql.connection.MySQLConnection.connect'): + with patch('tap_mysql.connection.run_sql') as run_sql_mock: + run_sql_mock.side_effect = self.run_sql_mock + conn = MySQLConnectionMock(config={**test_utils.get_db_config(), + **{'session_sqls': session_sqls}}) + connect_with_backoff(conn) + + # Test if session variables applied on connection + self.assertEqual(self.executed_queries, ['SET SESSION max_statement_time=0', + 'SET SESSION wait_timeout=28800']) + + if __name__ == "__main__": test1 = TestBinlogReplication() test1.setUp() diff --git a/tests/utils.py b/tests/utils.py index 7da96ec..d298af3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ def get_db_config(): return config -def get_test_connection(): +def get_test_connection(extra_config=None): db_config = get_db_config() con = pymysql.connect(**db_config) @@ -38,7 +38,9 @@ def get_test_connection(): db_config['database'] = DB_NAME db_config['autocommit'] = True - mysql_conn = MySQLConnection(db_config) + if not extra_config: + extra_config = {} + mysql_conn = MySQLConnection({**db_config, **extra_config}) mysql_conn.autocommit_mode = True return mysql_conn