diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py index 7754421a..f13fe92d 100644 --- a/target_snowflake/__init__.py +++ b/target_snowflake/__init__.py @@ -166,12 +166,19 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT if primary_key_string not in records_to_load[stream]: row_count[stream] += 1 total_row_count[stream] += 1 + records_to_load[stream][primary_key_string] = {} - # append record + # merge record into batch if config.get('add_metadata_columns') or config.get('hard_delete'): - records_to_load[stream][primary_key_string] = stream_utils.add_metadata_values_to_record(o) + records_to_load[stream][primary_key_string] = merge_records( + records_to_load[stream][primary_key_string], + stream_utils.add_metadata_values_to_record(o) + ) else: - records_to_load[stream][primary_key_string] = o['record'] + records_to_load[stream][primary_key_string] = merge_records( + records_to_load[stream][primary_key_string], + o['record'] + ) if archive_load_files and stream in archive_load_files_data: # Keep track of min and max of the designated column @@ -334,6 +341,8 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT # emit latest state emit_state(copy.deepcopy(flushed_state)) +def merge_records(existing: dict, update: dict): + return {**existing, **update} # pylint: disable=too-many-arguments def flush_streams( diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py index 38d7757a..dcb2129d 100644 --- a/tests/integration/test_target_snowflake.py +++ b/tests/integration/test_target_snowflake.py @@ -1386,6 +1386,13 @@ def test_deletion_does_not_set_column_data_to_null(self): for _column, value in subject[0].items(): self.assertIsNotNone(value) + # Insert and Delete for cid 4 in table logical1_table2 happens in a single batch. Validate that record message + # of the deletion does not overwrite all data from the insert within the batch. + subject = self.snowflake.query(f'SELECT cid, cvarchar, _sdc_deleted_at FROM' + f' {self.config["default_target_schema"]}.logical1_table2 WHERE cid = \'4\';') + for _column, value in subject[0].items(): + self.assertIsNotNone(value) + tap_lines_update = test_utils.get_test_tap_lines('messages-pg-logical-streams-update.json') self.persist_lines_with_cache(tap_lines_update) diff --git a/tests/unit/test_target_snowflake.py b/tests/unit/test_target_snowflake.py index 63e34b27..f96725fa 100644 --- a/tests/unit/test_target_snowflake.py +++ b/tests/unit/test_target_snowflake.py @@ -174,3 +174,13 @@ def test_persist_lines_with_only_state_messages(self, dbSync_mock, flush_streams buf.getvalue().strip(), '{"bookmarks":{"tap_mysql_test-test_simple_table":{"replication_key":"id",' '"replication_key_value":100,"version":1}}}') + + def test_merge_records(self): + existing_record = {'a': 1, 'b': None, 'c': 'foo', 'd': 1} + new_record = {'a': 2, 'c': None, 'e': '2'} + + merged_records = target_snowflake.merge_records(existing_record, new_record) + + expected = {'a': 2, 'b': None, 'c': None, 'd': 1, 'e': '2'} + + self.assertEqual(merged_records, expected)