TDL-17934: POC on RuleMap implementation #132

Open · wants to merge 17 commits into crest-master
149 changes: 121 additions & 28 deletions tap_stripe/__init__.py
@@ -3,6 +3,8 @@
import json
import logging
import re
import copy


from datetime import datetime, timedelta
import stripe
@@ -12,6 +14,7 @@
import singer
from singer import utils, Transformer, metrics
from singer import metadata
from tap_stripe.rule_map import RuleMap

REQUIRED_CONFIG_KEYS = [
"start_date",
@@ -283,39 +286,96 @@ def load_schemas():

return schemas

def get_discovery_metadata(schema, key_properties, replication_method, replication_key):
def get_discovery_metadata(schema, key_properties, replication_method, replication_key, rule_map, stream_name):
mdata = metadata.new()
mdata = metadata.write(mdata, (), 'table-key-properties', key_properties)
mdata = metadata.write(mdata, (), 'forced-replication-method', replication_method)

if replication_key:
mdata = metadata.write(mdata, (), 'valid-replication-keys', [replication_key])

if 'stream_name' in rule_map:
# Write the stream's original name into the top-level metadata
mdata = metadata.write(mdata, (), 'original-name', stream_name)

for field_name in schema['properties'].keys():
if field_name in key_properties or field_name in [replication_key, "updated"]:
mdata = metadata.write(mdata, ('properties', field_name), 'inclusion', 'automatic')
else:
mdata = metadata.write(mdata, ('properties', field_name), 'inclusion', 'available')

# Also add metadata for nested (child) fields whose names differ from their original names.
add_child_into_metadata(schema['properties'][field_name], metadata, mdata, rule_map, ('properties', field_name), )
if ('properties', field_name) in rule_map:
mdata.get(('properties', field_name)).update({'original-name': rule_map[('properties', field_name)]})

return metadata.to_list(mdata)
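# Illustrative sketch (not part of this diff): with a hypothetical rule map that renamed the
# API field `displayName` to `display_name`, get_discovery_metadata() would emit entries
# roughly shaped like the literal below; the renamed field carries an `original-name` entry
# next to the usual inclusion metadata.
EXAMPLE_DISCOVERY_METADATA = [
    {'breadcrumb': [], 'metadata': {'table-key-properties': ['id'],
                                    'forced-replication-method': 'INCREMENTAL',
                                    'valid-replication-keys': ['created']}},
    {'breadcrumb': ['properties', 'id'], 'metadata': {'inclusion': 'automatic'}},
    {'breadcrumb': ['properties', 'display_name'],
     'metadata': {'inclusion': 'available', 'original-name': 'displayName'}},
]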

def add_child_into_metadata(schema, m_data, mdata, rule_map, parent=()):
"""
Add metadata for nested (child) fields whose names differ from their original names.
"""
if schema and isinstance(schema, dict) and schema.get('properties'):
for key in schema['properties'].keys():
# Prepare the key used to look up the field's original name in the rule_map object.
# The key is a tuple of breadcrumb items.
breadcrumb = parent + ('properties', key)

# Recurse so that every field of the schema is visited.
add_child_into_metadata(schema['properties'][key], m_data, mdata, rule_map, breadcrumb)

mdata = m_data.write(mdata, breadcrumb, 'inclusion', 'available')

if breadcrumb in rule_map:
# Add an `original-name` entry to the metadata containing the field's actual (API) name.
mdata.get(breadcrumb).update({'original-name': rule_map[breadcrumb]})

if schema.get('anyOf'):
for schema_fields in schema.get('anyOf'):
add_child_into_metadata(schema_fields, m_data, mdata, rule_map, parent)

def discover():
if schema and isinstance(schema, dict) and schema.get('items'):
breadcrumb = parent + ('items',)
add_child_into_metadata(schema['items'], m_data, mdata, rule_map, breadcrumb)
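# Illustrative sketch (not part of this diff): breadcrumbs mirror the schema nesting. The
# field names below are made up; a rule_map keyed by breadcrumb might look like this, and
# add_child_into_metadata() records each entry's raw API name under `original-name` while
# recursing through `properties`, `items` and `anyOf`.
EXAMPLE_RULE_MAP = {
    ('properties', 'display_name'): 'displayName',
    ('properties', 'address', 'properties', 'postal_code'): 'postalCode',
    ('properties', 'lines', 'items', 'properties', 'unit_amount'): 'unitAmount',
}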

def discover(rule_map):
raw_schemas = load_schemas()
streams = []

for stream_name, stream_map in STREAM_SDK_OBJECTS.items():
schema = raw_schemas[stream_name]['schema']
refs = load_shared_schema_refs()

# Get resolved schema
resolved_schema = singer.resolve_schema_references(schema, refs)

# Initialize this stream's entry in GetStdFieldsFromApiFields
rule_map.GetStdFieldsFromApiFields[stream_name] = {}

# Some schemas share the same $ref, so after resolution the same sub-schema object can
# appear in several places. Changing fields through one reference therefore changes every
# place that ref is used, which caused the `original-name` entry to be missing from the
# catalog metadata. To avoid mutating the actual schema, we apply the rule map to a deep
# copy and leave the original schema untouched.
copied_schema = copy.deepcopy(resolved_schema)
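# Illustrative sketch (not part of this diff) of the aliasing problem the deep copy avoids:
# after schema-reference resolution, two schemas may hold the *same* sub-schema object, so
# renaming a field in place would leak into both. The names below are made up; `copy` is
# already imported at the top of this module.
_shared_address_ref = {'properties': {'postalCode': {'type': ['null', 'string']}}}
_schema_a = {'properties': {'address': _shared_address_ref}}
_schema_b = {'properties': {'billing_address': _shared_address_ref}}
_copy_of_a = copy.deepcopy(_schema_a)
_props = _copy_of_a['properties']['address']['properties']
_props['postal_code'] = _props.pop('postalCode')  # rename only in the deep copy
assert 'postalCode' in _schema_b['properties']['billing_address']['properties']  # _schema_b unchanged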

# Get updated schema by applying rule map
standard_resolved_schema = rule_map.apply_ruleset_on_schema(resolved_schema, copied_schema, stream_name)

# Get standard name of stream
standard_stream_name = rule_map.apply_rule_set_on_stream_name(stream_name)
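# Illustrative sketch (not part of this diff): the actual rules live in tap_stripe/rule_map.py.
# Assuming the rule set simply standardizes names to lower snake_case, a minimal stand-in for
# apply_rule_set_on_stream_name could be (`re` is imported at the top of this module):
def _example_standardize_name(name):
    # Replace spaces and hyphens with underscores and lower-case the result (assumed rule).
    return re.sub(r'[\s\-]+', '_', name.strip()).lower()
# e.g. _example_standardize_name('Balance Transactions') returns 'balance_transactions'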

# create and add catalog entry
catalog_entry = {
'stream': stream_name,
'tap_stream_id': stream_name,
'schema': singer.resolve_schema_references(schema, refs),
'metadata': get_discovery_metadata(schema,
'stream': standard_stream_name,
'tap_stream_id': standard_stream_name,
'schema': standard_resolved_schema,
'metadata': get_discovery_metadata(standard_resolved_schema,
stream_map['key_properties'],
'INCREMENTAL',
STREAM_REPLICATION_KEY.get(stream_name)),
STREAM_REPLICATION_KEY.get(stream_name),
rule_map.GetStdFieldsFromApiFields[stream_name],
stream_name),
# Events may have a different key property than this. Change
# if it's appropriate.
'key_properties': stream_map['key_properties']
@@ -470,13 +530,17 @@ def convert_dict_to_stripe_object(record):

# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
def sync_stream(stream_name):
def sync_stream(stream_name, api_stream_name, rule_map):
"""
Sync each stream, looking for newly created records. Updates are captured by events stream.
"""
LOGGER.info("Started syncing stream %s", stream_name)

stream_metadata = metadata.to_map(Context.get_catalog_entry(stream_name)['metadata'])

# Populate the rule_map object from the original-name entries available in the metadata
rule_map.fill_rule_map_object_by_catalog(stream_name, stream_metadata)
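# Illustrative sketch (not part of this diff): assuming fill_rule_map_object_by_catalog simply
# inverts the catalog's `original-name` entries into an API-name -> standard-name lookup for
# this stream, a minimal stand-in could be:
def _example_fill_rule_map(stream_metadata):
    api_to_std = {}
    for breadcrumb, meta in stream_metadata.items():
        if breadcrumb and 'original-name' in meta:
            # breadcrumb is e.g. ('properties', 'display_name'); its last item is the
            # standardized field name, while `original-name` holds the raw API name.
            api_to_std[meta['original-name']] = breadcrumb[-1]
    return api_to_std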

stream_field_whitelist = json.loads(Context.config.get('whitelist_map', '{}')).get(stream_name)

extraction_time = singer.utils.now()
@@ -490,7 +554,13 @@ def sync_stream(stream_name):
bookmark = stream_bookmark

# if this stream has a sub_stream, compare the bookmark
sub_stream_name = SUB_STREAMS.get(stream_name)
if SUB_STREAMS.get(api_stream_name):
sub_stream_name = rule_map.apply_rule_set_on_stream_name(SUB_STREAMS.get(api_stream_name))
else:
sub_stream_name = SUB_STREAMS.get(stream_name)



# If there is a sub-stream and it's selected, get its bookmark (or the start date if no bookmark)
should_sync_sub_stream = sub_stream_name and Context.is_selected(sub_stream_name)
@@ -519,7 +589,7 @@ def sync_stream(stream_name):
# observed a short lag period between when records are created and
# when they are available via the API, so these streams will need
# a short lookback window.
if stream_name in IMMUTABLE_STREAMS:
if api_stream_name in IMMUTABLE_STREAMS:
# pylint:disable=fixme
# TODO: This may be an issue for other streams' created_at
# entries, but to keep the surface small, doing this only for
@@ -536,24 +606,26 @@ def sync_stream(stream_name):
stop_window = end_time

for stream_obj in paginate(
STREAM_SDK_OBJECTS[stream_name]['sdk_object'],
STREAM_SDK_OBJECTS[api_stream_name]['sdk_object'],
filter_key,
start_window,
stop_window,
stream_name,
STREAM_SDK_OBJECTS[stream_name].get('request_args')
api_stream_name,
STREAM_SDK_OBJECTS[api_stream_name].get('request_args')
):

# get the replication key value from the object
rec = unwrap_data_objects(stream_obj.to_dict_recursive())
# convert the dict record's fields to the `stripe.stripe_object.StripeObject` type
rec = convert_dict_to_stripe_object(rec)
rec = reduce_foreign_keys(rec, stream_name)
rec = reduce_foreign_keys(rec, api_stream_name)
stream_obj_created = rec[replication_key]
rec['updated'] = stream_obj_created

# sync stream if object is greater than or equal to the bookmark
if stream_obj_created >= stream_bookmark:
rec = rule_map.apply_ruleset_on_api_response(rec, stream_name)
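# Illustrative sketch (not part of this diff): assuming apply_ruleset_on_api_response renames
# the record's raw API keys back to the standardized names from the catalog, a simplified
# (non-recursive) stand-in would be the helper below; api_to_std is the lookup built from the
# catalog's `original-name` entries (see the earlier sketch).
def _example_apply_rules(record, api_to_std):
    return {api_to_std.get(key, key): value for key, value in record.items()}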

rec = transformer.transform(rec,
Context.get_catalog_entry(stream_name)['schema'],
stream_metadata)
@@ -573,7 +645,12 @@ def sync_stream(stream_name):
# sync sub streams if its selected and the parent object
# is greater than its bookmark
if should_sync_sub_stream and stream_obj_created > sub_stream_bookmark:
sync_sub_stream(sub_stream_name, stream_obj)
# Populate the rule_map object for the sub-stream from its catalog metadata
rule_map.fill_rule_map_object_by_catalog(sub_stream_name, metadata.to_map(
Context.get_catalog_entry(sub_stream_name)['metadata']
))

sync_sub_stream(sub_stream_name, stream_obj, rule_map)

# Update stream/sub-streams bookmarks as stop window
if stop_window > stream_bookmark:
@@ -621,7 +698,7 @@ def get_object_list_iterator(object_list):
# we are in a cycle.
INITIAL_SUB_STREAM_OBJECT_LIST_LENGTH = 10

def sync_sub_stream(sub_stream_name, parent_obj, updates=False):
def sync_sub_stream(sub_stream_name, parent_obj, rule_map, updates=False):
"""
Given a parent object, retrieve its values for the specified substream.
"""
@@ -711,7 +788,8 @@ def sync_sub_stream(sub_stream_name, parent_obj, updates=False):
# payout_transactions is a join table
obj_ad_dict = {"id": obj_ad_dict['id'], "payout_id": parent_obj['id']}

rec = transformer.transform(unwrap_data_objects(obj_ad_dict),
rec = rule_map.apply_ruleset_on_api_response(unwrap_data_objects(obj_ad_dict), sub_stream_name)
rec = transformer.transform(rec,
Context.get_catalog_entry(sub_stream_name)['schema'],
metadata.to_map(
Context.get_catalog_entry(sub_stream_name)['metadata']
@@ -762,7 +840,7 @@ def recursive_to_dict(some_obj):
# Else just return
return some_obj

def sync_event_updates(stream_name):
def sync_event_updates(stream_name, api_stream_name, rule_map):
'''
Get updates via events endpoint

@@ -804,12 +882,16 @@ def sync_event_updates(stream_name):

for events_obj in response.auto_paging_iter():
event_resource_obj = events_obj.data.object
sub_stream_name = SUB_STREAMS.get(stream_name)

if SUB_STREAMS.get(api_stream_name):
sub_stream_name = rule_map.apply_rule_set_on_stream_name(SUB_STREAMS.get(api_stream_name))
else:
sub_stream_name = SUB_STREAMS.get(stream_name)


# Check whether we should sync the event based on its created time
if not should_sync_event(events_obj,
STREAM_TO_TYPE_FILTER[stream_name]['object'],
STREAM_TO_TYPE_FILTER[api_stream_name]['object'],
updated_object_timestamps):
continue

@@ -834,6 +916,8 @@ def sync_event_updates(stream_name):
rec = unwrap_data_objects(rec)
rec = reduce_foreign_keys(rec, api_stream_name)
rec["updated"] = events_obj.created
rec = rule_map.apply_ruleset_on_api_response(rec, stream_name)

rec["updated_by_event_type"] = events_obj.type
rec = transformer.transform(
rec,
@@ -856,6 +940,7 @@ def sync_event_updates(stream_name):
if event_resource_obj:
sync_sub_stream(sub_stream_name,
event_resource_obj,
rule_map,
updates=True)
if events_obj.created > max_created:
max_created = events_obj.created
@@ -872,7 +957,7 @@ def sync_event_updates(stream_name):

singer.write_state(Context.state)

def sync():
def sync(rule_map):
# Write all schemas and init count to 0
for catalog_entry in Context.catalog['streams']:
stream_name = catalog_entry["tap_stream_id"]
@@ -887,36 +972,44 @@ def sync():
# Loop over streams in catalog
for catalog_entry in Context.catalog['streams']:
stream_name = catalog_entry['tap_stream_id']

if catalog_entry.get('metadata')[0].get('metadata').get('api-name'):
api_stream_name = catalog_entry.get('metadata')[0].get('metadata').get('api-name')
else:
api_stream_name = stream_name
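# Illustrative sketch (not part of this diff): the lookup above expects the stream's top-level
# metadata entry (empty breadcrumb) to carry the raw API stream name. The names below are made
# up; when no `api-name` is present, the standardized name and the API name are treated as identical.
EXAMPLE_TOP_LEVEL_METADATA = {
    'breadcrumb': [],
    'metadata': {'api-name': 'balanceTransactions',
                 'table-key-properties': ['id'],
                 'forced-replication-method': 'INCREMENTAL'},
}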

# Sync records for stream
if Context.is_selected(stream_name) and not Context.is_sub_stream(stream_name):
sync_stream(stream_name)
if Context.is_selected(stream_name) and not Context.is_sub_stream(api_stream_name):
sync_stream(stream_name, api_stream_name, rule_map)
# This prevents us from retrieving 'events.events'
if STREAM_TO_TYPE_FILTER.get(stream_name):
sync_event_updates(stream_name)
if STREAM_TO_TYPE_FILTER.get(api_stream_name):
sync_event_updates(stream_name, api_stream_name, rule_map)

@utils.handle_top_exception(LOGGER)
def main():
# Parse command line arguments
args = utils.parse_args(REQUIRED_CONFIG_KEYS)

rule_map = RuleMap()

# If discover flag was passed, run discovery mode and dump output to stdout
if args.discover:
catalog = discover()
catalog = discover(rule_map)
print(json.dumps(catalog, indent=2))
# Otherwise run in sync mode
else:
Context.tap_start = utils.now()
if args.catalog:
Context.catalog = args.catalog.to_dict()
else:
Context.catalog = discover()
Context.catalog = discover(rule_map)

Context.config = args.config
Context.state = args.state
configure_stripe_client()
validate_dependencies()
try:
sync()
sync(rule_map)
finally:
# Print counts
Context.print_counts()