This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Add option to flush streams when memory threshold reached #283

Open · wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion README.md
@@ -179,7 +179,7 @@ Full list of options in `config.json`:
| archive_load_files | Boolean | | (Default: False) When enabled, the files loaded to Snowflake will also be stored in `archive_load_files_s3_bucket` under the key `/{archive_load_files_s3_prefix}/{schema_name}/{table_name}/`. All archived files will have `tap`, `schema`, `table` and `archived-by` as S3 metadata keys. When incremental replication is used, the archived files will also have the following S3 metadata keys: `incremental-key`, `incremental-key-min` and `incremental-key-max`.
| archive_load_files_s3_prefix | String | | (Default: "archive") When `archive_load_files` is enabled, the archived files will be placed in the archive S3 bucket under this prefix.
| archive_load_files_s3_bucket | String | | (Default: Value of `s3_bucket`) When `archive_load_files` is enabled, the archived files will be placed in this bucket.

| max_memory_threshold | Number | | (Default: 0.0) Force a flush of all streams when total memory consumption rises above this threshold, e.g. a value of 0.9 flushes all streams once total memory usage exceeds 90%. A value of 0.0 disables the feature. When running this target in a container, e.g. on Kubernetes, make sure to set a memory limit so that this feature works properly.
### To run tests:

1. Define the environment variables that are required to run the tests by creating a `.env` file in `tests/integration`, or by exporting the variables below.
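For orientation, here is a minimal sketch of how these flush-related settings reach `persist_lines()` once `config.json` is parsed; the values are illustrative, and only the options relevant to flushing are shown:

```python
# Illustrative config excerpt (hypothetical values); other required settings are omitted.
config = {
    "batch_size_rows": 100000,    # existing: flush a stream after this many buffered rows
    "flush_all_streams": False,   # existing: flush every stream whenever any stream flushes
    "max_memory_threshold": 0.9,  # added by this PR: flush all streams above 90% memory usage
}

# Inside persist_lines() the new option is read the same way as the existing ones:
max_memory_threshold = float(config.get('max_memory_threshold', 0))  # 0.0 disables the check
```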
1 change: 1 addition & 0 deletions setup.py
@@ -28,6 +28,7 @@
'joblib==1.1.0',
'ujson==5.2.0',
'boto3==1.21',
'psutil==5.9.1'
],
extras_require={
"test": [
41 changes: 38 additions & 3 deletions target_snowflake/__init__.py
@@ -5,6 +5,7 @@
import ujson
import logging
import os
import psutil
import sys
import copy

@@ -34,7 +35,7 @@
DEFAULT_BATCH_SIZE_ROWS = 100000
DEFAULT_PARALLELISM = 0 # 0 The number of threads used to flush tables
DEFAULT_MAX_PARALLELISM = 16 # Don't use more than this number of threads by default when flushing streams in parallel

MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS = 1000

def add_metadata_columns_to_schema(schema_message):
"""Metadata _sdc columns according to the stitch documentation at
@@ -111,11 +112,13 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
row_count = {}
stream_to_sync = {}
total_row_count = {}
total_row_count_all_streams = 0
batch_size_rows = config.get('batch_size_rows', DEFAULT_BATCH_SIZE_ROWS)
batch_wait_limit_seconds = config.get('batch_wait_limit_seconds', None)
flush_timestamp = datetime.utcnow()
archive_load_files = config.get('archive_load_files', False)
archive_load_files_data = {}
max_memory_threshold = float(config.get('max_memory_threshold', 0))

# Loop over lines from stdin
for line in lines:
@@ -166,6 +169,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
if primary_key_string not in records_to_load[stream]:
row_count[stream] += 1
total_row_count[stream] += 1
total_row_count_all_streams += 1

# append record
if config.get('add_metadata_columns') or config.get('hard_delete'):
@@ -189,7 +193,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
stream_archive_load_files_values['max'] = incremental_key_value

flush = False
if row_count[stream] >= batch_size_rows:
flush_all_streams = config.get('flush_all_streams')

if max_memory_threshold and \
total_row_count_all_streams % MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS == 0 and \
current_memory_consumption_percentage() > max_memory_threshold:
flush = True
flush_all_streams = True
LOGGER.info("Flush triggered by memory threshold")
elif row_count[stream] >= batch_size_rows:
flush = True
LOGGER.info("Flush triggered by batch_size_rows (%s) reached in %s",
batch_size_rows, stream)
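A note on the new trigger above: the memory probe is amortized, running only when `total_row_count_all_streams` is a multiple of `MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS` (1000), and a hit forces a flush of all streams regardless of the `flush_all_streams` config flag. A minimal, self-contained sketch of that decision, with hypothetical helper names standing in for the PR's inline logic:

```python
def should_flush_for_memory(rows_seen_all_streams, threshold, memory_pct, check_every=1000):
    """Stand-in for the check in persist_lines(): True when the periodic
    memory probe is due and current usage exceeds the configured threshold."""
    if not threshold:                          # 0.0 disables the feature
        return False
    if rows_seen_all_streams % check_every:    # probe memory only every `check_every` rows
        return False
    return memory_pct() > threshold

# With a 0.9 threshold: row 2000 triggers a flush at 95% usage, row 2001 does not.
assert should_flush_for_memory(2000, 0.9, lambda: 0.95) is True
assert should_flush_for_memory(2001, 0.9, lambda: 0.95) is False
```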
@@ -201,7 +213,7 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT

if flush:
# flush all streams, delete records if needed, reset counts and then emit current state
if config.get('flush_all_streams'):
if flush_all_streams:
filter_streams = None
else:
filter_streams = [stream]
@@ -335,6 +347,29 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
emit_state(copy.deepcopy(flushed_state))



def current_memory_consumption_percentage():
    return current_memory_usage_bytes() / memory_limit_bytes()


def current_memory_usage_bytes():
    # Try to read cgroup stats first in case we run in a container
    try:
        with open('/sys/fs/cgroup/memory/memory.usage_in_bytes', 'r') as f:
            return int(f.readline().strip())
    except FileNotFoundError:
        return psutil.virtual_memory().total - psutil.virtual_memory().available


def memory_limit_bytes():
    # Try to read cgroup stats first in case we run in a container
    try:
        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
            return int(f.readline().strip())
    except FileNotFoundError:
        return psutil.virtual_memory().total


# pylint: disable=too-many-arguments
def flush_streams(
streams,
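The memory helpers above prefer the cgroup v1 accounting files so that a containerized run is measured against its own limit, and fall back to host-wide psutil figures when those files are absent. A rough sketch of the combined behaviour under that same assumption (the helper name here is hypothetical):

```python
import psutil

def _read_cgroup_int(path, fallback):
    # Mirrors the fallback pattern used in the PR: cgroup file if present, psutil otherwise.
    try:
        with open(path, 'r') as f:
            return int(f.readline().strip())
    except FileNotFoundError:
        return fallback

vm = psutil.virtual_memory()
usage = _read_cgroup_int('/sys/fs/cgroup/memory/memory.usage_in_bytes', vm.total - vm.available)
limit = _read_cgroup_int('/sys/fs/cgroup/memory/memory.limit_in_bytes', vm.total)
print(f"memory consumption: {usage / limit:.1%}")  # compared against max_memory_threshold
```

This is also why the README asks for a container memory limit: without one, `memory.limit_in_bytes` typically reports an effectively unlimited value, so the computed percentage stays near zero and the threshold is never reached.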
40 changes: 39 additions & 1 deletion tests/unit/test_target_snowflake.py
@@ -3,10 +3,11 @@
import unittest
import os
import itertools
from types import SimpleNamespace

from contextlib import redirect_stdout
from datetime import datetime, timedelta
from unittest.mock import patch
from unittest.mock import patch, mock_open

import target_snowflake

@@ -41,6 +42,31 @@ def test_persist_lines_with_40_records_and_batch_size_of_20_expect_flushing_once

self.assertEqual(1, flush_streams_mock.call_count)

    @patch('target_snowflake.current_memory_consumption_percentage')
    @patch('target_snowflake.flush_streams')
    @patch('target_snowflake.DbSync')
    @patch('target_snowflake.MAX_MEMORY_THRESHOLD_CHECK_EVERY_N_ROWS', 1)
    def test_persist_lines_with_memory_threshold_reached_expect_multiple_flushings(self, dbSync_mock,
                                                                                   flush_streams_mock,
                                                                                   current_memory_consumption_percentage_mock
                                                                                   ):
        self.config['batch_size_rows'] = 20
        self.config['max_memory_threshold'] = 0.1

        with open(f'{os.path.dirname(__file__)}/resources/same-schemas-multiple-times.json', 'r') as f:
            lines = f.readlines()

        current_memory_consumption_percentage_mock.return_value = 0.2
        instance = dbSync_mock.return_value
        instance.create_schema_if_not_exists.return_value = None
        instance.sync_table.return_value = None

        flush_streams_mock.return_value = '{"currently_syncing": null}'

        target_snowflake.persist_lines(self.config, lines)

        self.assertEqual(5, flush_streams_mock.call_count)

@patch('target_snowflake.flush_streams')
@patch('target_snowflake.DbSync')
def test_persist_lines_with_same_schema_expect_flushing_once(self, dbSync_mock,
@@ -174,3 +200,15 @@ def test_persist_lines_with_only_state_messages(self, dbSync_mock, flush_streams
            buf.getvalue().strip(),
            '{"bookmarks":{"tap_mysql_test-test_simple_table":{"replication_key":"id",'
            '"replication_key_value":100,"version":1}}}')

    def test_current_memory_consumption_percentage_cgroup(self):
        with patch('builtins.open', mock_open(read_data='20'), create=True) as mock_builtin_open:
            assert target_snowflake.current_memory_consumption_percentage() == 1.0

    @patch('psutil.virtual_memory')
    def test_current_memory_consumption_percentage_psutil(self, psutil_virtual_memory_mock):
        vmem = SimpleNamespace()
        vmem.available = 20
        vmem.total = 100
        psutil_virtual_memory_mock.return_value = vmem
        assert target_snowflake.current_memory_consumption_percentage() == 0.8
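For reference, the expected value in the psutil fallback test follows directly from the helpers: usage = total - available = 100 - 20 = 80 bytes, limit = total = 100 bytes, so the ratio is 80 / 100 = 0.8.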