Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
[AP-822] S3 profile based auth, session_token, AWS stage creds from e…
Browse files Browse the repository at this point in the history
…nv vars (#93)
  • Loading branch information
koszti authored Jul 30, 2020
1 parent 0fe359b commit 85ebb71
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 24 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,15 @@ Full list of options in `config.json`:
| user | String | Yes | Snowflake User |
| password | String | Yes | Snowflake Password |
| warehouse | String | Yes | Snowflake virtual warehouse name |
| aws_access_key_id | String | No | S3 Access Key Id. If not provided, AWS_ACCESS_KEY_ID environment variable or IAM role will be used |
| aws_secret_access_key | String | No | S3 Secret Access Key. If not provided, AWS_SECRET_ACCESS_KEY environment variable or IAM role will be used |
| aws_session_token | String | No | AWS Session token. If not provided, AWS_SESSION_TOKEN environment variable will be used |
| s3_acl | String | No | S3 ACL name |
| aws_access_key_id | String | No | S3 Access Key Id. If not provided, `AWS_ACCESS_KEY_ID` environment variable or IAM role will be used |
| aws_secret_access_key | String | No | S3 Secret Access Key. If not provided, `AWS_SECRET_ACCESS_KEY` environment variable or IAM role will be used |
| aws_session_token | String | No | AWS Session token. If not provided, `AWS_SESSION_TOKEN` environment variable will be used |
| aws_profile | String | No | AWS profile name for profile based authentication. If not provided, `AWS_PROFILE` environment variable will be used. |
| s3_bucket | String | Yes | S3 Bucket name |
| s3_key_prefix | String | No | (Default: None) A static prefix before the generated S3 key names. Using prefixes you can upload files into specific directories in the S3 bucket. |
| s3_endpoint_url | String | No | The complete URL to use for the constructed client. This allows the use of S3-compatible storage that is not hosted on a native AWS S3 account. |
| s3_region_name | String | No | Default region when creating new connections |
| s3_acl | String | No | S3 ACL name to set on the uploaded files |
| stage | String | Yes | Name of the external stage created in the pre-requirements section. Has to be a fully qualified name including the schema name. |
| file_format | String | Yes | Name of the file format created in the pre-requirements section. Has to be a fully qualified name including the schema name. |
| batch_size_rows | Integer | | (Default: 100000) Maximum number of rows in each batch. At the end of each batch, the rows in the batch are loaded into Snowflake. |
Expand Down
33 changes: 27 additions & 6 deletions target_snowflake/db_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,33 @@ def __init__(self, connection_config, stream_schema_message=None, table_cache=No
self.data_flattening_max_level = self.connection_config.get('data_flattening_max_level', 0)
self.flatten_schema = flatten_schema(stream_schema_message['schema'], max_level=self.data_flattening_max_level)

self.s3 = boto3.client(
's3',
aws_access_key_id=self.connection_config.get('aws_access_key_id'),
aws_secret_access_key=self.connection_config.get('aws_secret_access_key'),
aws_session_token=self.connection_config.get('aws_session_token')
)
self.s3 = self.create_s3_client()

def create_s3_client(self, config=None):
    """Build an S3 client from the connection config and AWS environment variables.

    Every credential setting is looked up in the config first; when it is
    missing or empty there, the matching AWS environment variable is used.

    :param config: Optional settings dictionary. When it is falsy (``None``
        or empty), ``self.connection_config`` is used instead.
    :return: The S3 client created from the resulting boto3 session.
    """
    settings = config if config else self.connection_config

    def first_defined(config_key, env_var):
        # A value in the config wins; the environment variable is the fallback
        return settings.get(config_key) or os.environ.get(env_var)

    profile_name = first_defined('aws_profile', 'AWS_PROFILE')
    access_key_id = first_defined('aws_access_key_id', 'AWS_ACCESS_KEY_ID')
    secret_access_key = first_defined('aws_secret_access_key', 'AWS_SECRET_ACCESS_KEY')
    session_token = first_defined('aws_session_token', 'AWS_SESSION_TOKEN')

    if access_key_id and secret_access_key:
        # A complete key pair takes priority over profile based authentication
        session_args = {
            'aws_access_key_id': access_key_id,
            'aws_secret_access_key': secret_access_key,
            'aws_session_token': session_token,
        }
    else:
        # No complete key pair: authenticate with the (possibly unset) profile
        session_args = {'profile_name': profile_name}

    session = boto3.session.Session(**session_args)

    return session.client(
        's3',
        region_name=settings.get('s3_region_name'),
        endpoint_url=settings.get('s3_endpoint_url'),
    )

def open_connection(self):
return snowflake.connector.connect(
Expand Down
94 changes: 80 additions & 14 deletions tests/integration/test_target_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
import mock
import os
import botocore

from nose.tools import assert_raises

Expand Down Expand Up @@ -915,24 +916,89 @@ def test_loading_tables_with_custom_temp_dir(self):

self.assert_three_streams_are_into_snowflake()

# NOTE(review): two consecutive `def` lines follow. This looks like a rendered
# diff that shows the previous and the current signature together - confirm
# which of the two definitions is the live one.
def test_using_aws_environment_variables(self):
"""Test loading data with aws in the environment rather than explicitly provided access keys"""
def test_aws_env_vars(self):
"""Test loading data with credentials defined in AWS environment variables
than explicitly provided access keys"""
tap_lines = test_utils.get_test_tap_lines("messages-with-three-streams.json")

try:
# NOTE(review): os.environ.get() returns None when the TARGET_SNOWFLAKE_*
# variable is unset, and assigning None into os.environ raises TypeError.
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"TARGET_SNOWFLAKE_AWS_ACCESS_KEY"
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"TARGET_SNOWFLAKE_AWS_SECRET_ACCESS_KEY"
)
self.config["aws_access_key_id"] = None
self.config["aws_secret_access_key"] = None

target_snowflake.persist_lines(self.config, tap_lines)
# NOTE(review): if the statements above are live code, this snapshot is taken
# after the access keys were already set to None - confirm the intended order.
# Save original config to restore later
orig_config = self.config.copy()

# Move aws access key and secret from config into environment variables
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('TARGET_SNOWFLAKE_AWS_ACCESS_KEY')
os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('TARGET_SNOWFLAKE_AWS_SECRET_ACCESS_KEY')
del self.config['aws_access_key_id']
del self.config['aws_secret_access_key']

# Create a new S3 client using env vars
snowflake = DbSync(self.config)
snowflake.create_s3_client()

# Restore the original state to not confuse other tests
finally:
# The AWS variables are removed outright; a value that existed before the
# test is not put back.
del os.environ['AWS_ACCESS_KEY_ID']
del os.environ['AWS_SECRET_ACCESS_KEY']
self.config = orig_config.copy()

def test_profile_based_auth(self):
    """Check that a non-existing AWS profile in the config is rejected.

    The access keys are removed from the config so that the S3 client has to
    fall back to profile based authentication.
    """
    try:
        # Keep a copy of the config so it can be put back afterwards
        config_backup = self.config.copy()

        # Drop the key pair and point the config at a profile that does not exist
        for key_name in ('aws_access_key_id', 'aws_secret_access_key'):
            del self.config[key_name]
        self.config['aws_profile'] = 'fake-profile'

        # Creating the S3 client with the unknown profile has to fail
        with assert_raises(botocore.exceptions.ProfileNotFound):
            db_sync = DbSync(self.config)
            db_sync.create_s3_client()

    finally:
        # Put the original config back to not confuse other tests
        self.config = config_backup.copy()

def test_profile_based_auth_aws_env_var(self):
    """Test AWS profile based authentication using AWS environment variables

    The access keys are removed from the config and ``AWS_PROFILE`` is set to
    a profile name that is expected not to exist, so creating the S3 client
    has to raise ``ProfileNotFound``.
    """
    # Take the snapshots before entering the try block, so the cleanup never
    # refers to a name that was not assigned
    orig_config = self.config.copy()
    orig_aws_profile = os.environ.get('AWS_PROFILE')

    try:
        # Remove access keys from config and add profile name environment variable
        del self.config['aws_access_key_id']
        del self.config['aws_secret_access_key']
        os.environ['AWS_PROFILE'] = 'fake_profile'

        # Create a new S3 client using profile based authentication
        with assert_raises(botocore.exceptions.ProfileNotFound):
            snowflake = DbSync(self.config)
            snowflake.create_s3_client()

    # Restore the original state to not confuse other tests
    finally:
        if orig_aws_profile is None:
            # The variable did not exist before; pop() with a default cannot
            # raise, even when the test failed before setting the variable
            os.environ.pop('AWS_PROFILE', None)
        else:
            # Put back the caller's own profile instead of discarding it
            os.environ['AWS_PROFILE'] = orig_aws_profile
        self.config = orig_config.copy()

def test_s3_custom_endpoint_url(self):
    """Test S3 connection with a custom (invalid) endpoint URL

    An endpoint URL that is not a valid URL is put into the config; creating
    the S3 client is then expected to raise ``ValueError``.
    """
    # Take the snapshot before entering the try block, so the cleanup never
    # refers to a name that was not assigned
    orig_config = self.config.copy()

    try:
        # Define custom S3 endpoint
        self.config['s3_endpoint_url'] = 'fake-endpoint-url'

        # Botocore should raise ValueError in case of fake S3 endpoint url
        with assert_raises(ValueError):
            snowflake = DbSync(self.config)
            snowflake.create_s3_client()

    # Restore the original state to not confuse other tests
    finally:
        # This test never touches the AWS_* environment variables, so only the
        # config has to be restored. Deleting variables that were not set here
        # would raise KeyError or remove credentials owned by the environment.
        self.config = orig_config.copy()

def test_too_many_records_exception(self):
"""Test if query function raise exception if max_records exceeded"""
Expand Down

0 comments on commit 85ebb71

Please sign in to comment.