Skip to content

Commit

Permalink
chore: Enforce importing sqlalchemy as sa (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon authored Jan 23, 2024
1 parent 78a1063 commit 402d023
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 112 deletions.
26 changes: 18 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,25 @@ vcs = "git"
style = "semver"

[tool.ruff]
target-version = "py38"

[tool.ruff.lint]
select = [
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"ICN", # flake8-import-conventions
"RUF", # ruff
]
target-version = "py38"

[tool.ruff.pydocstyle]
[tool.ruff.lint.flake8-import-conventions]
banned-from = ["sqlalchemy"]

[tool.ruff.lint.flake8-import-conventions.extend-aliases]
sqlalchemy = "sa"

[tool.ruff.lint.pydocstyle]
convention = "google"
114 changes: 54 additions & 60 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import paramiko
import simplejson
import sqlalchemy
import sqlalchemy as sa
from singer_sdk import SQLConnector
from singer_sdk import typing as th
from sqlalchemy.dialects.postgresql import ARRAY, BIGINT, JSONB
Expand Down Expand Up @@ -84,10 +84,10 @@ def prepare_table( # type: ignore[override]
full_table_name: str,
schema: dict,
primary_keys: list[str],
connection: sqlalchemy.engine.Connection,
connection: sa.engine.Connection,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> sqlalchemy.Table:
) -> sa.Table:
"""Adapt target table to provided schema if possible.
Args:
Expand All @@ -102,8 +102,8 @@ def prepare_table( # type: ignore[override]
The table object.
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData(schema=schema_name)
table: sqlalchemy.Table
meta = sa.MetaData(schema=schema_name)
table: sa.Table
if not self.table_exists(full_table_name=full_table_name):
table = self.create_empty_table(
table_name=table_name,
Expand Down Expand Up @@ -143,10 +143,10 @@ def prepare_table( # type: ignore[override]
def copy_table_structure(
self,
full_table_name: str,
from_table: sqlalchemy.Table,
connection: sqlalchemy.engine.Connection,
from_table: sa.Table,
connection: sa.engine.Connection,
as_temp_table: bool = False,
) -> sqlalchemy.Table:
) -> sa.Table:
"""Copy table structure.
Args:
Expand All @@ -159,58 +159,54 @@ def copy_table_structure(
The new table object.
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData(schema=schema_name)
new_table: sqlalchemy.Table
meta = sa.MetaData(schema=schema_name)
new_table: sa.Table
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")
for column in from_table.columns:
columns.append(column._copy())
if as_temp_table:
new_table = sqlalchemy.Table(
table_name, meta, *columns, prefixes=["TEMPORARY"]
)
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
new_table.create(bind=connection)
return new_table
else:
new_table = sqlalchemy.Table(table_name, meta, *columns)
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table

@contextmanager
def _connect(self) -> t.Iterator[sqlalchemy.engine.Connection]:
def _connect(self) -> t.Iterator[sa.engine.Connection]:
with self._engine.connect().execution_options() as conn:
yield conn

def drop_table(
self, table: sqlalchemy.Table, connection: sqlalchemy.engine.Connection
):
def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
"""Drop table data."""
table.drop(bind=connection)

def clone_table(
self, new_table_name, table, metadata, connection, temp_table
) -> sqlalchemy.Table:
) -> sa.Table:
"""Clone a table."""
new_columns = []
for column in table.columns:
new_columns.append(
sqlalchemy.Column(
sa.Column(
column.name,
column.type,
)
)
if temp_table is True:
new_table = sqlalchemy.Table(
new_table = sa.Table(
new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
)
else:
new_table = sqlalchemy.Table(new_table_name, metadata, *new_columns)
new_table = sa.Table(new_table_name, metadata, *new_columns)
new_table.create(bind=connection)
return new_table

@staticmethod
def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
"""Return a JSON Schema representation of the provided type.
By default will call `typing.to_sql_type()`.
Expand Down Expand Up @@ -317,13 +313,13 @@ def pick_best_sql_type(sql_type_array: list):
def create_empty_table( # type: ignore[override]
self,
table_name: str,
meta: sqlalchemy.MetaData,
meta: sa.MetaData,
schema: dict,
connection: sqlalchemy.engine.Connection,
connection: sa.engine.Connection,
primary_keys: list[str] | None = None,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> sqlalchemy.Table:
) -> sa.Table:
"""Create an empty target table.
Args:
Expand All @@ -342,7 +338,7 @@ def create_empty_table( # type: ignore[override]
NotImplementedError: if temp tables are unsupported and as_temp_table=True.
RuntimeError: if a variant schema is passed with no properties defined.
"""
columns: list[sqlalchemy.Column] = []
columns: list[sa.Column] = []
primary_keys = primary_keys or []
try:
properties: dict = schema["properties"]
Expand All @@ -355,31 +351,29 @@ def create_empty_table( # type: ignore[override]
for property_name, property_jsonschema in properties.items():
is_primary_key = property_name in primary_keys
columns.append(
sqlalchemy.Column(
sa.Column(
property_name,
self.to_sql_type(property_jsonschema),
primary_key=is_primary_key,
autoincrement=False, # See: https://github.com/MeltanoLabs/target-postgres/issues/193 # noqa: E501
)
)
if as_temp_table:
new_table = sqlalchemy.Table(
table_name, meta, *columns, prefixes=["TEMPORARY"]
)
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
new_table.create(bind=connection)
return new_table

new_table = sqlalchemy.Table(table_name, meta, *columns)
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table

def prepare_column(
self,
full_table_name: str,
column_name: str,
sql_type: sqlalchemy.types.TypeEngine,
connection: sqlalchemy.engine.Connection | None = None,
column_object: sqlalchemy.Column | None = None,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection | None = None,
column_object: sa.Column | None = None,
) -> None:
"""Adapt target table to provided schema if possible.
Expand All @@ -402,7 +396,7 @@ def prepare_column(

if not column_exists:
self._create_empty_column(
# We should migrate every function to use sqlalchemy.Table
# We should migrate every function to use sa.Table
# instead of having to know what the function wants
table_name=table_name,
column_name=column_name,
Expand All @@ -426,8 +420,8 @@ def _create_empty_column( # type: ignore[override]
schema_name: str,
table_name: str,
column_name: str,
sql_type: sqlalchemy.types.TypeEngine,
connection: sqlalchemy.engine.Connection,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection,
) -> None:
"""Create a new column.
Expand Down Expand Up @@ -458,8 +452,8 @@ def get_column_add_ddl( # type: ignore[override]
table_name: str,
schema_name: str,
column_name: str,
column_type: sqlalchemy.types.TypeEngine,
) -> sqlalchemy.DDL:
column_type: sa.types.TypeEngine,
) -> sa.DDL:
"""Get the create column DDL statement.
Args:
Expand All @@ -471,9 +465,9 @@ def get_column_add_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
column = sqlalchemy.Column(column_name, column_type)
column = sa.Column(column_name, column_type)

return sqlalchemy.DDL(
return sa.DDL(
(
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
"ADD COLUMN %(column_name)s %(column_type)s"
Expand All @@ -491,9 +485,9 @@ def _adapt_column_type( # type: ignore[override]
schema_name: str,
table_name: str,
column_name: str,
sql_type: sqlalchemy.types.TypeEngine,
connection: sqlalchemy.engine.Connection,
column_object: sqlalchemy.Column | None,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection,
column_object: sa.Column | None,
) -> None:
"""Adapt table column type to support the new JSON schema type.
Expand All @@ -508,9 +502,9 @@ def _adapt_column_type( # type: ignore[override]
Raises:
NotImplementedError: if altering columns is not supported.
"""
current_type: sqlalchemy.types.TypeEngine
current_type: sa.types.TypeEngine
if column_object is not None:
current_type = t.cast(sqlalchemy.types.TypeEngine, column_object.type)
current_type = t.cast(sa.types.TypeEngine, column_object.type)
else:
current_type = self._get_column_type(
schema_name=schema_name,
Expand Down Expand Up @@ -561,8 +555,8 @@ def get_column_alter_ddl( # type: ignore[override]
schema_name: str,
table_name: str,
column_name: str,
column_type: sqlalchemy.types.TypeEngine,
) -> sqlalchemy.DDL:
column_type: sa.types.TypeEngine,
) -> sa.DDL:
"""Get the alter column DDL statement.
Override this if your database uses a different syntax for altering columns.
Expand All @@ -576,8 +570,8 @@ def get_column_alter_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
column = sqlalchemy.Column(column_name, column_type)
return sqlalchemy.DDL(
column = sa.Column(column_name, column_type)
return sa.DDL(
(
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
"ALTER COLUMN %(column_name)s %(column_type)s"
Expand Down Expand Up @@ -700,7 +694,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
paramiko.Ed25519Key,
):
try:
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined] # noqa: E501
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined]
except paramiko.SSHException:
continue
else:
Expand Down Expand Up @@ -728,8 +722,8 @@ def _get_column_type( # type: ignore[override]
schema_name: str,
table_name: str,
column_name: str,
connection: sqlalchemy.engine.Connection,
) -> sqlalchemy.types.TypeEngine:
connection: sa.engine.Connection,
) -> sa.types.TypeEngine:
"""Get the SQL type of the declared column.
Args:
Expand Down Expand Up @@ -757,15 +751,15 @@ def _get_column_type( # type: ignore[override]
)
raise KeyError(msg) from ex

return t.cast(sqlalchemy.types.TypeEngine, column.type)
return t.cast(sa.types.TypeEngine, column.type)

def get_table_columns( # type: ignore[override]
self,
schema_name: str,
table_name: str,
connection: sqlalchemy.engine.Connection,
connection: sa.engine.Connection,
column_names: list[str] | None = None,
) -> dict[str, sqlalchemy.Column]:
) -> dict[str, sa.Column]:
"""Return a list of table columns.
Overrode to support schema_name
Expand All @@ -779,11 +773,11 @@ def get_table_columns( # type: ignore[override]
Returns:
An ordered list of column objects.
"""
inspector = sqlalchemy.inspect(connection)
inspector = sa.inspect(connection)
columns = inspector.get_columns(table_name, schema_name)

return {
col_meta["name"]: sqlalchemy.Column(
col_meta["name"]: sa.Column(
col_meta["name"],
col_meta["type"],
nullable=col_meta.get("nullable", False),
Expand All @@ -797,7 +791,7 @@ def column_exists( # type: ignore[override]
self,
full_table_name: str,
column_name: str,
connection: sqlalchemy.engine.Connection,
connection: sa.engine.Connection,
) -> bool:
"""Determine if the target column already exists.
Expand Down
Loading

0 comments on commit 402d023

Please sign in to comment.