From 85ebb71809a46b6a804cb64b9d42e09e716c6cce Mon Sep 17 00:00:00 2001 From: Peter Kosztolanyi Date: Thu, 30 Jul 2020 19:13:44 +0200 Subject: [PATCH] [AP-822] S3 profile based auth, session_token, AWS stage creds from env vars (#93) --- README.md | 11 ++- target_snowflake/db_sync.py | 33 ++++++-- tests/integration/test_target_snowflake.py | 94 ++++++++++++++++++---- 3 files changed, 114 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index a2e808f0..f3186966 100644 --- a/README.md +++ b/README.md @@ -131,12 +131,15 @@ Full list of options in `config.json`: | user | String | Yes | Snowflake User | | password | String | Yes | Snowflake Password | | warehouse | String | Yes | Snowflake virtual warehouse name | -| aws_access_key_id | String | No | S3 Access Key Id. If not provided, AWS_ACCESS_KEY_ID environment variable or IAM role will be used | -| aws_secret_access_key | String | No | S3 Secret Access Key. If not provided, AWS_SECRET_ACCESS_KEY environment variable or IAM role will be used | -| aws_session_token | String | No | AWS Session token. If not provided, AWS_SESSION_TOKEN environment variable will be used | -| s3_acl | String | No | S3 ACL name | +| aws_access_key_id | String | No | S3 Access Key Id. If not provided, `AWS_ACCESS_KEY_ID` environment variable or IAM role will be used | +| aws_secret_access_key | String | No | S3 Secret Access Key. If not provided, `AWS_SECRET_ACCESS_KEY` environment variable or IAM role will be used | +| aws_session_token | String | No | AWS Session token. If not provided, `AWS_SESSION_TOKEN` environment variable will be used | +| aws_profile | String | No | AWS profile name for profile based authentication. If not provided, `AWS_PROFILE` environment variable will be used. | | s3_bucket | String | Yes | S3 Bucket name | | s3_key_prefix | String | No | (Default: None) A static prefix before the generated S3 key names. Using prefixes you can upload files into specific directories in the S3 bucket. 
| +| s3_endpoint_url | String | No | The complete URL to use for the constructed client. This allows using a non-native S3 account. | +| s3_region_name | String | No | Default region when creating new connections | +| s3_acl | String | No | S3 ACL name to set on the uploaded files | | stage | String | Yes | Named external stage name created at pre-requirements section. Has to be a fully qualified name including the schema name | | file_format | String | Yes | Named file format name created at pre-requirements section. Has to be a fully qualified name including the schema name. | | batch_size_rows | Integer | | (Default: 100000) Maximum number of rows in each batch. At the end of each batch, the rows in the batch are loaded into Snowflake. | diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py index 02896b28..36c636b5 100644 --- a/target_snowflake/db_sync.py +++ b/target_snowflake/db_sync.py @@ -288,12 +288,33 @@ def __init__(self, connection_config, stream_schema_message=None, table_cache=No self.data_flattening_max_level = self.connection_config.get('data_flattening_max_level', 0) self.flatten_schema = flatten_schema(stream_schema_message['schema'], max_level=self.data_flattening_max_level) - self.s3 = boto3.client( - 's3', - aws_access_key_id=self.connection_config.get('aws_access_key_id'), - aws_secret_access_key=self.connection_config.get('aws_secret_access_key'), - aws_session_token=self.connection_config.get('aws_session_token') - ) + self.s3 = self.create_s3_client() + + def create_s3_client(self, config=None): + if not config: + config = self.connection_config + + # Get the required parameters from config file and/or environment variables + aws_profile = config.get('aws_profile') or os.environ.get('AWS_PROFILE') + aws_access_key_id = config.get('aws_access_key_id') or os.environ.get('AWS_ACCESS_KEY_ID') + aws_secret_access_key = config.get('aws_secret_access_key') or os.environ.get('AWS_SECRET_ACCESS_KEY') + aws_session_token = 
config.get('aws_session_token') or os.environ.get('AWS_SESSION_TOKEN') + + # AWS credentials based authentication + if aws_access_key_id and aws_secret_access_key: + aws_session = boto3.session.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token + ) + # AWS Profile based authentication + else: + aws_session = boto3.session.Session(profile_name=aws_profile) + + # Create the s3 client + return aws_session.client('s3', + region_name=config.get('s3_region_name'), + endpoint_url=config.get('s3_endpoint_url')) def open_connection(self): return snowflake.connector.connect( diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py index 4f6026e9..0866ae5b 100644 --- a/tests/integration/test_target_snowflake.py +++ b/tests/integration/test_target_snowflake.py @@ -3,6 +3,7 @@ import unittest import mock import os +import botocore from nose.tools import assert_raises @@ -915,24 +916,89 @@ def test_loading_tables_with_custom_temp_dir(self): self.assert_three_streams_are_into_snowflake() - def test_using_aws_environment_variables(self): - """Test loading data with aws in the environment rather than explicitly provided access keys""" + def test_aws_env_vars(self): + """Test loading data with credentials defined in AWS environment variables + rather than explicitly provided access keys""" tap_lines = test_utils.get_test_tap_lines("messages-with-three-streams.json") try: - os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( - "TARGET_SNOWFLAKE_AWS_ACCESS_KEY" - ) - os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( - "TARGET_SNOWFLAKE_AWS_SECRET_ACCESS_KEY" - ) - self.config["aws_access_key_id"] = None - self.config["aws_secret_access_key"] = None - - target_snowflake.persist_lines(self.config, tap_lines) + # Save original config to restore later + orig_config = self.config.copy() + + # Move aws access key and secret from config into environment variables + 
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('TARGET_SNOWFLAKE_AWS_ACCESS_KEY') + os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('TARGET_SNOWFLAKE_AWS_SECRET_ACCESS_KEY') + del self.config['aws_access_key_id'] + del self.config['aws_secret_access_key'] + + # Create a new S3 client using env vars + snowflake = DbSync(self.config) + snowflake.create_s3_client() + + # Restore the original state to not confuse other tests + finally: + del os.environ['AWS_ACCESS_KEY_ID'] + del os.environ['AWS_SECRET_ACCESS_KEY'] + self.config = orig_config.copy() + + def test_profile_based_auth(self): + """Test AWS profile based authentication rather than access keys""" + try: + # Save original config to restore later + orig_config = self.config.copy() + + # Remove access keys from config and add profile name + del self.config['aws_access_key_id'] + del self.config['aws_secret_access_key'] + self.config['aws_profile'] = 'fake-profile' + + # Create a new S3 client using profile based authentication + with assert_raises(botocore.exceptions.ProfileNotFound): + snowflake = DbSync(self.config) + snowflake.create_s3_client() + + # Restore the original state to not confuse other tests + finally: + self.config = orig_config.copy() + + def test_profile_based_auth_aws_env_var(self): + """Test AWS profile based authentication using AWS environment variables""" + try: + # Save original config to restore later + orig_config = self.config.copy() + + # Remove access keys from config and add profile name environment variable + del self.config['aws_access_key_id'] + del self.config['aws_secret_access_key'] + os.environ['AWS_PROFILE'] = 'fake_profile' + + # Create a new S3 client using profile based authentication + with assert_raises(botocore.exceptions.ProfileNotFound): + snowflake = DbSync(self.config) + snowflake.create_s3_client() + + # Restore the original state to not confuse other tests + finally: + del os.environ['AWS_PROFILE'] + self.config = orig_config.copy() + + def 
test_s3_custom_endpoint_url(self): + """Test S3 connection with custom region and endpoint URL""" + try: + # Save original config to restore later + orig_config = self.config.copy() + + # Define custom S3 endpoint + self.config['s3_endpoint_url'] = 'fake-endpoint-url' + + # Botocore should raise ValueError in case of fake S3 endpoint url + with assert_raises(ValueError): + snowflake = DbSync(self.config) + snowflake.create_s3_client() + + # Restore the original state to not confuse other tests finally: - del os.environ["AWS_ACCESS_KEY_ID"] - del os.environ["AWS_SECRET_ACCESS_KEY"] + self.config = orig_config.copy() def test_too_many_records_exception(self): """Test if query function raise exception if max_records exceeded"""