Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Draft: Make 'file_format' optional #219

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ __pycache__/
*~
dist/
.coverage
.mypy_cache

# Singer JSON files
properties.json
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ You need to create a few objects in snowflake in one schema before start using t

1. Create a named file format. This will be used by the MERGE/COPY commands to parse the files correctly from S3. You can use CSV or Parquet file formats.

To use CSV files:
To use the default (CSV) file format option:

Omit the `file_format` setting entirely to use the default CSV format options. (The setting's presence is what selects a named format, so an empty value is not the same as leaving it out.)

To use a named CSV file format:

```
CREATE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'CSV' ESCAPE='\\' FIELD_OPTIONALLY_ENCLOSED_BY='"';
Expand Down
15 changes: 7 additions & 8 deletions target_snowflake/db_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from singer import get_logger
from target_snowflake import flattening
from target_snowflake import stream_utils
from target_snowflake.file_format import FileFormat, FileFormatTypes
from target_snowflake.file_format import FileFormat, FileFormatTypes, InlineFileFormat

from target_snowflake.exceptions import TooManyRecordsException, PrimaryKeyNotFoundException
from target_snowflake.upload_clients.s3_upload_client import S3UploadClient
Expand All @@ -27,7 +27,6 @@ def validate_config(config):
'warehouse',
's3_bucket',
'stage',
'file_format'
]

snowflake_required_config_keys = [
Expand All @@ -36,7 +35,6 @@ def validate_config(config):
'user',
'password',
'warehouse',
'file_format'
]

required_config_keys = []
Expand Down Expand Up @@ -212,7 +210,10 @@ def __init__(self, connection_config, stream_schema_message=None, table_cache=No

self.schema_name = None
self.grantees = None
self.file_format = FileFormat(self.connection_config['file_format'], self.query, file_format_type)
if 'file_format' in self.connection_config:
self.file_format = FileFormat(self.connection_config['file_format'], self.query, file_format_type)
else:
self.file_format = InlineFileFormat(file_format_type or FileFormatTypes.CSV)

if not self.connection_config.get('stage') and self.file_format.file_format_type == FileFormatTypes.PARQUET:
self.logger.error("Table stages with Parquet file format is not suppported. "
Expand Down Expand Up @@ -469,8 +470,7 @@ def load_file(self, s3_key, count, size_bytes):
merge_sql = self.file_format.formatter.create_merge_sql(table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
s3_key=s3_key,
file_format_name=
self.connection_config['file_format'],
file_format=self.file_format,
columns=columns_with_trans,
pk_merge_condition=
self.primary_key_merge_condition())
Expand All @@ -488,8 +488,7 @@ def load_file(self, s3_key, count, size_bytes):
copy_sql = self.file_format.formatter.create_copy_sql(table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
s3_key=s3_key,
file_format_name=
self.connection_config['file_format'],
file_format=self.file_format,
columns=columns_with_trans)
self.logger.debug('Running query: %s', copy_sql)
cur.execute(copy_sql)
Expand Down
96 changes: 82 additions & 14 deletions target_snowflake/file_format.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Enums used by pipelinewise-target-snowflake"""
import abc
from enum import Enum, unique
from types import ModuleType
from typing import Callable
Expand All @@ -20,20 +21,13 @@ def list():
return list(map(lambda c: c.value, FileFormatTypes))


# pylint: disable=too-few-public-methods
class FileFormat:
"""File Format class"""
class FileFormat(abc.ABC):
    """Abstract base of the Snowflake file format helpers.

    BUG FIX: this class previously inherited from ``abc.ABCMeta``.  ABCMeta
    is a *metaclass*; subclassing it makes FileFormat (and every subclass) a
    metaclass, so ``NamedFileFormat(fmt, query_fn)`` would be routed through
    ``type.__call__``/``type.__new__`` (which takes 1 or 3 arguments) and
    fail, and abstract members would never be enforced.  ``abc.ABC`` is the
    intended spelling.
    """

    def __init__(self, file_format_type: FileFormatTypes):
        """Initialize the format type and the format specific helpers.

        Args:
            file_format_type: member of the FileFormatTypes enum.
        """
        self.file_format_type = file_format_type
        # The formatter is the module (csv/parquet) that implements the
        # create_copy_sql / create_merge_sql helpers for this type.
        self.formatter = self._get_formatter(file_format_type)

@classmethod
def _get_formatter(cls, file_format_type: FileFormatTypes) -> ModuleType:
Expand All @@ -44,7 +38,7 @@ def _get_formatter(cls, file_format_type: FileFormatTypes) -> ModuleType:
file_format_type: FileFormatTypes enum item

Returns:
ModuleType implementation of the file ormatter
ModuleType implementation of the file formatter
"""
formatter = None

Expand All @@ -57,6 +51,36 @@ def _get_formatter(cls, file_format_type: FileFormatTypes) -> ModuleType:

return formatter

# ``abc.abstractproperty`` has been deprecated since Python 3.3; the modern
# spelling stacks ``@property`` over ``@abc.abstractmethod``.
@property
@abc.abstractmethod
def declaration_for_copy(self) -> str:
    """Return the format declaration text for a COPY INTO statement."""

@property
@abc.abstractmethod
def declaration_for_merge(self) -> str:
    """Return the format declaration text for a MERGE statement."""


# pylint: disable=too-few-public-methods
class NamedFileFormat(FileFormat):
    """File format backed by a named Snowflake FILE FORMAT object."""

    def __init__(self,
                 file_format: str,
                 query_fn: Callable,
                 file_format_type: FileFormatTypes = None):
        """Resolve the named file format and set up format specific helpers.

        When no explicit type is supplied, the type is detected by querying
        the named FILE FORMAT object in Snowflake.
        """
        self.qualified_format_name = file_format
        resolved_type = file_format_type or self._detect_file_format_type(file_format, query_fn)
        super().__init__(resolved_type)

@classmethod
def _detect_file_format_type(cls, file_format: str, query_fn: Callable) -> FileFormatTypes:
"""Detect the type of an existing snowflake file format object
Expand Down Expand Up @@ -84,3 +108,47 @@ def _detect_file_format_type(cls, file_format: str, query_fn: Callable) -> FileF
f"Named file format not found: {file_format}")

return file_format_type

@property
def declaration_for_copy(self) -> str:
    """FILE_FORMAT clause referencing the named format, for COPY INTO.

    BUG FIX: this must be a @property.  The base class declares it with
    ``abstractproperty`` and call sites interpolate it as an attribute
    (``f"{file_format.declaration_for_copy}"``); as a plain method the
    bound-method repr would end up inside the generated SQL.
    """
    return f"FILE_FORMAT = (format_name='{self.qualified_format_name}')"

@property
def declaration_for_merge(self) -> str:
    """FILE_FORMAT named-argument form used when a MERGE reads staged files.

    Same @property fix as ``declaration_for_copy`` above.
    """
    return f"FILE_FORMAT => '{self.qualified_format_name}'"


class InlineFileFormat(FileFormat):
    """File format declared inline in the generated SQL statements.

    Used when the connection config has no ``file_format`` entry, so there is
    no named Snowflake FILE FORMAT object to reference.  Only CSV is
    implemented.
    """

    def __init__(self, file_format_type: FileFormatTypes = None):
        """Validate the requested type and initialise the formatter.

        Raises:
            NotImplementedError: for any type other than CSV (including the
                default ``None`` — callers must pass FileFormatTypes.CSV).
        """
        if file_format_type != FileFormatTypes.CSV:
            raise NotImplementedError("Only CSV is supported as an inline format type.")
        # BUG FIX: reuse the base-class initialisation instead of duplicating
        # the attribute assignments here.
        super().__init__(file_format_type)

    @property
    def declaration_for_copy(self) -> str:
        """Inline FILE_FORMAT clause for a COPY INTO statement.

        BUG FIX: the concrete implementation was decorated with
        ``@abc.abstractproperty``, which both mislabels it as abstract and
        breaks the attribute-style access used by the call sites; a plain
        ``@property`` is the correct form.  The CSV-only check that lived
        here is redundant — ``__init__`` already guarantees CSV.
        """
        return (
            "FILE_FORMAT = (\n"
            "  TYPE = CSV\n"
            "  EMPTY_FIELD_AS_NULL = FALSE\n"
            "  FIELD_OPTIONALLY_ENCLOSED_BY = '\"'\n"
            ")\n"
        )

    @property
    def declaration_for_merge(self) -> str:
        """Inline FILE_FORMAT argument for querying staged files in MERGE.

        NOTE(review): Snowflake staged-file queries may only accept a *named*
        format via ``FILE_FORMAT =>`` — confirm inline options are accepted
        here before taking this off draft.
        """
        return (
            "FILE_FORMAT => (\n"
            "  TYPE = CSV\n"
            "  EMPTY_FIELD_AS_NULL = FALSE\n"
            "  FIELD_OPTIONALLY_ENCLOSED_BY = '\"'\n"
            ")"
        )
10 changes: 5 additions & 5 deletions target_snowflake/file_formats/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,25 @@
from tempfile import mkstemp

from target_snowflake import flattening

from target_snowflake.file_formats import FileFormat

def create_copy_sql(table_name: str,
                    stage_name: str,
                    s3_key: str,
                    file_format: "FileFormat",
                    columns: List):
    """Generate a CSV compatible snowflake COPY INTO command.

    Args:
        table_name: fully qualified target table name.
        stage_name: Snowflake stage holding the staged file.
        s3_key: key of the staged file under the stage.
        file_format: FileFormat whose ``declaration_for_copy`` property
            supplies the FILE_FORMAT clause (named or inline).  Annotated as
            a string on purpose: ``target_snowflake.file_format`` imports
            this module, so evaluating the annotation eagerly would require
            a circular import.
        columns: list of column dicts, each carrying a 'name' key.

    Returns:
        The COPY INTO statement as a single string.
    """
    p_columns = ', '.join([c['name'] for c in columns])

    return f"COPY INTO {table_name} ({p_columns}) " \
           f"FROM '@{stage_name}/{s3_key}' " \
           f"{file_format.declaration_for_copy}"


def create_merge_sql(table_name: str,
stage_name: str,
s3_key: str,
file_format_name: str,
file_format: FileFormat,
columns: List,
pk_merge_condition: str) -> str:
"""Generate a CSV compatible snowflake MERGE INTO command"""
Expand All @@ -37,7 +37,7 @@ def create_merge_sql(table_name: str,
return f"MERGE INTO {table_name} t USING (" \
f"SELECT {p_source_columns} " \
f"FROM '@{stage_name}/{s3_key}' " \
f"(FILE_FORMAT => '{file_format_name}')) s " \
f"({file_format.declaration_for_merge})) s " \
f"ON {pk_merge_condition} " \
f"WHEN MATCHED THEN UPDATE SET {p_update} " \
"WHEN NOT MATCHED THEN " \
Expand Down
9 changes: 5 additions & 4 deletions target_snowflake/file_formats/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from tempfile import mkstemp

from target_snowflake import flattening
from target_snowflake.file_formats import FileFormat


def create_copy_sql(table_name: str,
stage_name: str,
s3_key: str,
file_format_name: str,
file_format: FileFormat,
columns: List):
"""Generate a Parquet compatible snowflake COPY INTO command"""
p_target_columns = ', '.join([c['name'] for c in columns])
Expand All @@ -20,13 +21,13 @@ def create_copy_sql(table_name: str,

return f"COPY INTO {table_name} ({p_target_columns}) " \
f"FROM (SELECT {p_source_columns} FROM '@{stage_name}/{s3_key}') " \
f"FILE_FORMAT = (format_name='{file_format_name}')"
f"{file_format.declaration_for_copy}"


def create_merge_sql(table_name: str,
stage_name: str,
s3_key: str,
file_format_name: str,
file_format: FileFormat,
columns: List,
pk_merge_condition: str) -> str:
"""Generate a Parquet compatible snowflake MERGE INTO command"""
Expand All @@ -39,7 +40,7 @@ def create_merge_sql(table_name: str,
return f"MERGE INTO {table_name} t USING (" \
f"SELECT {p_source_columns} " \
f"FROM '@{stage_name}/{s3_key}' " \
f"(FILE_FORMAT => '{file_format_name}')) s " \
f"({file_format.declaration_for_merge})) s " \
f"ON {pk_merge_condition} " \
f"WHEN MATCHED THEN UPDATE SET {p_update} " \
"WHEN NOT MATCHED THEN " \
Expand Down