From 838da8f01356fe368e82263a9794ba386546a8ee Mon Sep 17 00:00:00 2001
From: Nils Mueller
Date: Tue, 24 May 2022 03:12:23 +0300
Subject: [PATCH] Add option to flush streams when memory threshold reached

The memory consumption of a job is often not foreseeable, especially
when the amount of data is highly inconsistent. Guarding against this
would either require setting the max batch size low enough to
accommodate every possible load pattern, or over-provisioning memory.

This change adds a safeguard against OOM events by flushing all streams
as soon as memory consumption reaches a configurable threshold.
---
 README.md                           |  2 +-
 setup.py                            |  1 +
 target_snowflake/__init__.py        | 41 ++++++++++++++++++++++++++---
 tests/unit/test_target_snowflake.py | 40 +++++++++++++++++++++++++++-
 4 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 19ce405a..55545da2 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@ Full list of options in `config.json`:
 | archive_load_files | Boolean | | (Default: False) When enabled, the files loaded to Snowflake will also be stored in `archive_load_files_s3_bucket` under the key `/{archive_load_files_s3_prefix}/{schema_name}/{table_name}/`. All archived files will have `tap`, `schema`, `table` and `archived-by` as S3 metadata keys. When incremental replication is used, the archived files will also have the following S3 metadata keys: `incremental-key`, `incremental-key-min` and `incremental-key-max`.
 | archive_load_files_s3_prefix | String | | (Default: "archive") When `archive_load_files` is enabled, the archived files will be placed in the archive S3 bucket under this prefix.
 | archive_load_files_s3_bucket | String | | (Default: Value of `s3_bucket`) When `archive_load_files` is enabled, the archived files will be placed in this bucket.
-
+| max_memory_threshold | Number | | (Default: 0.0) Force a flush of all streams when total memory consumption rises above this threshold, e.g. a value of 0.9 flushes all streams once total memory usage exceeds 90%. A value of 0.0 disables the feature. When running this target in a container, e.g. on Kubernetes, make sure to set a memory limit so this feature works properly.
 
 ### To run tests:
 
 1. Define the environment variables that are required to run the tests by creating a `.env` file in `tests/integration`, or by exporting the variables below.
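[Editor's note, not part of the patch: for illustration only, a config enabling the new option might combine it with the existing batch settings it interacts with; the values below are made up.]

    # Hypothetical config fragment, shown as a Python dict as it would be
    # parsed from config.json; only max_memory_threshold is new.
    config = {
        "batch_size_rows": 100000,      # existing per-stream batch limit
        "flush_all_streams": False,     # existing option; a memory-triggered flush forces this behaviour for that one flush
        "max_memory_threshold": 0.9,    # new: flush every stream once memory usage exceeds 90% of the limit
    }

With the default of 0.0 the check is skipped entirely, so existing setups are unaffected.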
diff --git a/setup.py b/setup.py
index 4e12948d..9891ba33 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
         'joblib==1.1.0',
         'ujson==5.2.0',
         'boto3==1.21',
+        'psutil==5.9.1'
     ],
     extras_require={
         "test": [
diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py
index 7754421a..f7876d43 100644
--- a/target_snowflake/__init__.py
+++ b/target_snowflake/__init__.py
@@ -5,6 +5,7 @@
 import ujson
 import logging
 import os
+import psutil
 import sys
 
 import copy
@@ -34,7 +35,7 @@
 DEFAULT_BATCH_SIZE_ROWS = 100000
 DEFAULT_PARALLELISM = 0  # 0 The number of threads used to flush tables
 DEFAULT_MAX_PARALLELISM = 16  # Don't use more than this number of threads by default when flushing streams in parallel
-
+MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS = 1000
 
 def add_metadata_columns_to_schema(schema_message):
     """Metadata _sdc columns according to the stitch documentation at
@@ -111,11 +112,13 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
     row_count = {}
     stream_to_sync = {}
     total_row_count = {}
+    total_row_count_all_streams = 0
     batch_size_rows = config.get('batch_size_rows', DEFAULT_BATCH_SIZE_ROWS)
     batch_wait_limit_seconds = config.get('batch_wait_limit_seconds', None)
     flush_timestamp = datetime.utcnow()
     archive_load_files = config.get('archive_load_files', False)
     archive_load_files_data = {}
+    max_memory_threshold = float(config.get('max_memory_threshold', 0))
 
     # Loop over lines from stdin
     for line in lines:
@@ -166,6 +169,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
                 if primary_key_string not in records_to_load[stream]:
                     row_count[stream] += 1
                     total_row_count[stream] += 1
+                    total_row_count_all_streams += 1
 
                 # append record
                 if config.get('add_metadata_columns') or config.get('hard_delete'):
@@ -189,7 +193,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
                         stream_archive_load_files_values['max'] = incremental_key_value
 
             flush = False
-            if row_count[stream] >= batch_size_rows:
+            flush_all_streams = config.get('flush_all_streams')
+
+            if max_memory_threshold and \
+                    total_row_count_all_streams % MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS == 0 and \
+                    current_memory_consumption_percentage() > max_memory_threshold:
+                flush = True
+                flush_all_streams = True
+                LOGGER.info("Flush triggered by memory threshold")
+            elif row_count[stream] >= batch_size_rows:
                 flush = True
                 LOGGER.info("Flush triggered by batch_size_rows (%s) reached in %s",
                             batch_size_rows, stream)
@@ -201,7 +213,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
             if flush:
                 # flush all streams, delete records if needed, reset counts and then emit current state
-                if config.get('flush_all_streams'):
+                if flush_all_streams:
                     filter_streams = None
                 else:
                     filter_streams = [stream]
@@ -335,6 +347,29 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
     emit_state(copy.deepcopy(flushed_state))
 
+
+def current_memory_consumption_percentage():
+    return current_memory_usage_bytes() / memory_limit_bytes()
+
+
+def current_memory_usage_bytes():
+    # Try to read cgroup stats first in case we run in a container
+    try:
+        with open('/sys/fs/cgroup/memory/memory.usage_in_bytes', 'r') as f:
+            return int(f.readline().strip())
+    except FileNotFoundError:
+        return psutil.virtual_memory().total - psutil.virtual_memory().available
+
+
+def memory_limit_bytes():
+    # Try to read cgroup stats first in case we run in a container
+    try:
+        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
+            return int(f.readline().strip())
+    except FileNotFoundError:
+        return psutil.virtual_memory().total
+
+
 # pylint: disable=too-many-arguments
 def flush_streams(
         streams,
diff --git a/tests/unit/test_target_snowflake.py b/tests/unit/test_target_snowflake.py
index 63e34b27..a632ba04 100644
--- a/tests/unit/test_target_snowflake.py
+++ b/tests/unit/test_target_snowflake.py
@@ -3,10 +3,11 @@
 import unittest
 import os
 import itertools
+from types import SimpleNamespace
 
 from contextlib import redirect_stdout
 from datetime import datetime, timedelta
-from unittest.mock import patch
+from unittest.mock import patch, mock_open
 
 import target_snowflake
@@ -41,6 +42,31 @@ def test_persist_lines_with_40_records_and_batch_size_of_20_expect_flushing_once
 
         self.assertEqual(1, flush_streams_mock.call_count)
 
+    @patch('target_snowflake.current_memory_consumption_percentage')
+    @patch('target_snowflake.flush_streams')
+    @patch('target_snowflake.DbSync')
+    @patch('target_snowflake.MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS', 1)
+    def test_persist_lines_with_memory_threshold_reached_expect_multiple_flushings(self, dbSync_mock,
+                                                                                   flush_streams_mock,
+                                                                                   current_memory_consumption_percentage_mock
+                                                                                   ):
+        self.config['batch_size_rows'] = 20
+        self.config['max_memory_threshold'] = 0.1
+
+        with open(f'{os.path.dirname(__file__)}/resources/same-schemas-multiple-times.json', 'r') as f:
+            lines = f.readlines()
+
+        current_memory_consumption_percentage_mock.return_value = 0.2
+        instance = dbSync_mock.return_value
+        instance.create_schema_if_not_exists.return_value = None
+        instance.sync_table.return_value = None
+
+        flush_streams_mock.return_value = '{"currently_syncing": null}'
+
+        target_snowflake.persist_lines(self.config, lines)
+
+        self.assertEqual(5, flush_streams_mock.call_count)
+
     @patch('target_snowflake.flush_streams')
     @patch('target_snowflake.DbSync')
     def test_persist_lines_with_same_schema_expect_flushing_once(self, dbSync_mock,
@@ -174,3 +200,15 @@ def test_persist_lines_with_only_state_messages(self, dbSync_mock, flush_streams
             buf.getvalue().strip(),
             '{"bookmarks":{"tap_mysql_test-test_simple_table":{"replication_key":"id",'
             '"replication_key_value":100,"version":1}}}')
+
+    def test_current_memory_consumption_percentage_cgroup(self):
+        with patch('builtins.open', mock_open(read_data='20'), create=True) as mock_builtin_open:
+            assert target_snowflake.current_memory_consumption_percentage() == 1.0
+
+    @patch('psutil.virtual_memory')
+    def test_current_memory_consumption_percentage_psutil(self, psutil_virtual_memory_mock):
+        vmem = SimpleNamespace()
+        vmem.available = 20
+        vmem.total = 100
+        psutil_virtual_memory_mock.return_value = vmem
+        assert target_snowflake.current_memory_consumption_percentage() == 0.8
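
[Editor's note, not part of the patch: the standalone sketch below mirrors the logic of current_memory_consumption_percentage() and its helpers, so the value compared against max_memory_threshold can be inspected on a given host. It assumes the same cgroup v1 paths the patch reads; where they are missing it falls back to psutil, as the patched code does.]

    # Rough standalone sketch of the memory-fraction calculation used above.
    import psutil

    def memory_fraction():
        """Fraction of the memory limit currently in use."""
        try:
            # cgroup v1 counters, as read by current_memory_usage_bytes()/memory_limit_bytes()
            with open('/sys/fs/cgroup/memory/memory.usage_in_bytes') as f:
                used = int(f.readline().strip())
            with open('/sys/fs/cgroup/memory/memory.limit_in_bytes') as f:
                limit = int(f.readline().strip())
        except FileNotFoundError:
            # Fallback outside containers: whole-machine memory via psutil
            vmem = psutil.virtual_memory()
            used, limit = vmem.total - vmem.available, vmem.total
        return used / limit

    if __name__ == '__main__':
        print(f"memory usage: {memory_fraction():.1%} of limit")

Note that on cgroup v2 hosts the v1 files do not exist (the counters live in /sys/fs/cgroup/memory.current and memory.max), so the psutil fallback reports whole-machine usage rather than the container limit; this is worth keeping in mind alongside the README's advice to set a memory limit on the container.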