diff --git a/README.md b/README.md
index 19ce405a..55545da2 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@ Full list of options in `config.json`:
 | archive_load_files | Boolean | | (Default: False) When enabled, the files loaded to Snowflake will also be stored in `archive_load_files_s3_bucket` under the key `/{archive_load_files_s3_prefix}/{schema_name}/{table_name}/`. All archived files will have `tap`, `schema`, `table` and `archived-by` as S3 metadata keys. When incremental replication is used, the archived files will also have the following S3 metadata keys: `incremental-key`, `incremental-key-min` and `incremental-key-max`.
 | archive_load_files_s3_prefix | String | | (Default: "archive") When `archive_load_files` is enabled, the archived files will be placed in the archive S3 bucket under this prefix.
 | archive_load_files_s3_bucket | String | | (Default: Value of `s3_bucket`) When `archive_load_files` is enabled, the archived files will be placed in this bucket.
-
+| max_memory_threshold | Number | | (Default: 0.0) Force flushing of all streams when total memory consumption rises above this threshold, e.g. a value of 0.9 will flush all streams when total memory usage goes above 90%. A value of 0.0 disables the feature. When running this target in a container, e.g. on Kubernetes, make sure to set a memory limit for this feature to work properly.
 ### To run tests:
 
 1. Define the environment variables that are required to run the tests by creating a `.env` file in `tests/integration`, or by exporting the variables below.
diff --git a/setup.py b/setup.py
index 4e12948d..9891ba33 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
         'joblib==1.1.0',
         'ujson==5.2.0',
         'boto3==1.21',
+        'psutil==5.9.1'
     ],
     extras_require={
         "test": [
diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py
index 7754421a..f7876d43 100644
--- a/target_snowflake/__init__.py
+++ b/target_snowflake/__init__.py
@@ -5,6 +5,7 @@
 import ujson
 import logging
 import os
+import psutil
 import sys
 import copy
 
@@ -34,7 +35,7 @@
 DEFAULT_BATCH_SIZE_ROWS = 100000
 DEFAULT_PARALLELISM = 0  # 0 The number of threads used to flush tables
 DEFAULT_MAX_PARALLELISM = 16  # Don't use more than this number of threads by default when flushing streams in parallel
-
+MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS = 1000
 
 def add_metadata_columns_to_schema(schema_message):
     """Metadata _sdc columns according to the stitch documentation at
@@ -111,11 +112,13 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
     row_count = {}
     stream_to_sync = {}
     total_row_count = {}
+    total_row_count_all_streams = 0
     batch_size_rows = config.get('batch_size_rows', DEFAULT_BATCH_SIZE_ROWS)
     batch_wait_limit_seconds = config.get('batch_wait_limit_seconds', None)
     flush_timestamp = datetime.utcnow()
     archive_load_files = config.get('archive_load_files', False)
     archive_load_files_data = {}
+    max_memory_threshold = float(config.get('max_memory_threshold', 0))
 
     # Loop over lines from stdin
     for line in lines:
@@ -166,6 +169,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
             if primary_key_string not in records_to_load[stream]:
                 row_count[stream] += 1
                 total_row_count[stream] += 1
+                total_row_count_all_streams += 1
 
             # append record
             if config.get('add_metadata_columns') or config.get('hard_delete'):
@@ -189,7 +193,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
                         stream_archive_load_files_values['max'] = incremental_key_value
 
             flush = False
-            if row_count[stream] >= batch_size_rows:
+            flush_all_streams = config.get('flush_all_streams')
+
+            if max_memory_threshold and \
+                    total_row_count_all_streams % MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS == 0 and \
+                    current_memory_consumption_percentage() > max_memory_threshold:
+                flush = True
+                flush_all_streams = True
+                LOGGER.info("Flush triggered by memory threshold")
+            elif row_count[stream] >= batch_size_rows:
                 flush = True
                 LOGGER.info("Flush triggered by batch_size_rows (%s) reached in %s",
                             batch_size_rows, stream)
@@ -201,7 +213,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
 
             if flush:
                 # flush all streams, delete records if needed, reset counts and then emit current state
-                if config.get('flush_all_streams'):
+                if flush_all_streams:
                     filter_streams = None
                 else:
                     filter_streams = [stream]
@@ -335,6 +347,29 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
     emit_state(copy.deepcopy(flushed_state))
 
 
+
+def current_memory_consumption_percentage():
+    return current_memory_usage_bytes() / memory_limit_bytes()
+
+
+def current_memory_usage_bytes():
+    # Try to read cgroup stats first in case we run in a container
+    try:
+        with open('/sys/fs/cgroup/memory/memory.usage_in_bytes', 'r') as f:
+            return int(f.readline().strip())
+    except FileNotFoundError:
+        return psutil.virtual_memory().total - psutil.virtual_memory().available
+
+
+def memory_limit_bytes():
+    # Try to read cgroup stats first in case we run in a container
+    try:
+        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
+            return int(f.readline().strip())
+    except FileNotFoundError:
+        return psutil.virtual_memory().total
+
+
 # pylint: disable=too-many-arguments
 def flush_streams(
         streams,
diff --git a/tests/unit/test_target_snowflake.py b/tests/unit/test_target_snowflake.py
index 63e34b27..a632ba04 100644
--- a/tests/unit/test_target_snowflake.py
+++ b/tests/unit/test_target_snowflake.py
@@ -3,10 +3,11 @@
 import unittest
 import os
 import itertools
 
+from types import SimpleNamespace
 from contextlib import redirect_stdout
 from datetime import datetime, timedelta
 
-from unittest.mock import patch
+from unittest.mock import patch, mock_open
 
 import target_snowflake
 
@@ -41,6 +42,31 @@ def test_persist_lines_with_40_records_and_batch_size_of_20_expect_flushing_once
 
         self.assertEqual(1, flush_streams_mock.call_count)
 
+    @patch('target_snowflake.current_memory_consumption_percentage')
+    @patch('target_snowflake.flush_streams')
+    @patch('target_snowflake.DbSync')
+    @patch('target_snowflake.MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS', 1)
+    def test_persist_lines_with_memory_threshold_reached_expect_multiple_flushings(self, dbSync_mock,
+                                                                                   flush_streams_mock,
+                                                                                   current_memory_consumption_percentage_mock
+                                                                                   ):
+        self.config['batch_size_rows'] = 20
+        self.config['max_memory_threshold'] = 0.1
+
+        with open(f'{os.path.dirname(__file__)}/resources/same-schemas-multiple-times.json', 'r') as f:
+            lines = f.readlines()
+
+        current_memory_consumption_percentage_mock.return_value = 0.2
+        instance = dbSync_mock.return_value
+        instance.create_schema_if_not_exists.return_value = None
+        instance.sync_table.return_value = None
+
+        flush_streams_mock.return_value = '{"currently_syncing": null}'
+
+        target_snowflake.persist_lines(self.config, lines)
+
+        self.assertEqual(5, flush_streams_mock.call_count)
+
     @patch('target_snowflake.flush_streams')
     @patch('target_snowflake.DbSync')
     def test_persist_lines_with_same_schema_expect_flushing_once(self, dbSync_mock,
@@ -174,3 +200,15 @@ def test_persist_lines_with_only_state_messages(self, dbSync_mock, flush_streams
             buf.getvalue().strip(),
             '{"bookmarks":{"tap_mysql_test-test_simple_table":{"replication_key":"id",'
             '"replication_key_value":100,"version":1}}}')
+
+    def test_current_memory_consumption_percentage_cgroup(self):
+        with patch('builtins.open', mock_open(read_data='20'), create=True) as mock_builtin_open:
+            assert target_snowflake.current_memory_consumption_percentage() == 1.0
+
+    @patch('psutil.virtual_memory')
+    def test_current_memory_consumption_percentage_psutil(self, psutil_virtual_memory_mock):
+        vmem = SimpleNamespace()
+        vmem.available = 20
+        vmem.total = 100
+        psutil_virtual_memory_mock.return_value = vmem
+        assert target_snowflake.current_memory_consumption_percentage() == 0.8
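
Usage note (illustrative, not part of the diff): with this change, `config.json` can combine the existing `batch_size_rows` batching with the new memory-based safety valve. A minimal sketch, assuming the required connection settings are filled in elsewhere and using arbitrary example values:

```json
{
  "batch_size_rows": 100000,
  "flush_all_streams": false,
  "max_memory_threshold": 0.9
}
```

With these values the target still flushes a stream every 100,000 rows as before, but whenever measured memory usage exceeds 90% of the limit (the cgroup limit when `/sys/fs/cgroup/memory/memory.limit_in_bytes` is readable, otherwise total system memory as reported by `psutil`), all streams are flushed at once. The check only runs every `MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS` (1000) records across all streams, so it adds negligible overhead on the hot path.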