Skip to content

Commit

Permalink
SQLAlchemy v1.4.x support (sqlalchemy-redshift#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
att14 authored Jul 6, 2021
1 parent 60b4db0 commit 4ecb81d
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ sudo: false
language: python
python:
- "2.7"
- "3.4"
- "3.5"
- "3.6"
- "3.7"
- "3.8"
- "3.9"

env:
global:
Expand Down
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
0.8.3 (unreleased)
------------------

- Nothing changed yet.
- SQLAlchemy 1.4.x support


0.8.2 (2021-01-08)
Expand Down
34 changes: 23 additions & 11 deletions sqlalchemy_redshift/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,13 +659,16 @@ def _get_table_or_view_names(self, relkind, connection, schema=None, **kw):
def _get_column_info(self, *args, **kwargs):
kw = kwargs.copy()
encode = kw.pop('encode', None)
if sa_version < Version('1.2.0'):
# SQLAlchemy 1.2.0 introduced the 'comment' param
del kw['comment']
if sa_version >= Version('1.3.16'):
# SQLAlchemy 1.3.16 introduced generated columns,
# not supported in redshift
kw['generated'] = ''

if sa_version < Version('1.4.0') and 'identity' in kw:
del kw['identity']
elif sa_version >= Version('1.4.0') and 'identity' not in kw:
kw['identity'] = None

column_info = super(RedshiftDialect, self)._get_column_info(
*args,
**kw
Expand Down Expand Up @@ -744,7 +747,7 @@ def _get_all_relation_info(self, connection, **kw):
@reflection.cache
def _get_all_column_info(self, connection, **kw):
all_columns = defaultdict(list)
with connection.contextual_connect() as cc:
with connection.connect() as cc:
result = cc.execute("""
SELECT
n.nspname as "schema",
Expand Down Expand Up @@ -918,17 +921,26 @@ def visit_delete_stmt(element, compiler, **kwargs):
# the tables in the using clause are sorted in the order in
# which they first appear in the where clause.
delete_stmt_table = compiler.process(element.table, asfrom=True, **kwargs)
whereclause_tuple = element.get_children()
if whereclause_tuple:
usingclause_tables = []
whereclause = ' WHERE {clause}'.format(
clause=compiler.process(*whereclause_tuple, **kwargs)
)

if sa_version >= Version('1.4.0'):
if element.whereclause is not None:
clause = compiler.process(element.whereclause, **kwargs)
if clause:
whereclause = ' WHERE {clause}'.format(clause=clause)
else:
whereclause_tuple = element.get_children()
if whereclause_tuple:
whereclause = ' WHERE {clause}'.format(
clause=compiler.process(*whereclause_tuple, **kwargs)
)

if whereclause:
usingclause_tables = []
whereclause_columns = gen_columns_from_children(element)
for col in whereclause_columns:
table = compiler.process(col.table, asfrom=True, **kwargs)
if table != delete_stmt_table and table not in usingclause_tables:
if table != delete_stmt_table and \
table not in usingclause_tables:
usingclause_tables.append(table)
if usingclause_tables:
usingclause = ' USING {clause}'.format(
Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
from urllib import parse as urlparse
except:
except ImportError:
import urlparse

import requests
Expand All @@ -31,6 +31,7 @@ def database_name_generator():
count=i,
)


database_name = functools.partial(next, database_name_generator())


Expand All @@ -50,7 +51,10 @@ def _database(self):
conn.execute('CREATE DATABASE {db_name}'.format(db_name=db_name))

dburl = copy.deepcopy(self.engine.url)
dburl.database = db_name
try:
dburl.database = db_name
except AttributeError:
dburl = dburl.set(database=db_name)

try:
yield db.EngineDefinition(
Expand Down
8 changes: 6 additions & 2 deletions tests/rs_sqla_test_utils/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import sqlalchemy as sa

from sqlalchemy import event
from sqlalchemy import DDL
from sqlalchemy.ext import declarative
from sqlalchemy.schema import CreateSchema


Base = declarative.declarative_base()
event.listen(Base.metadata, 'before_create', CreateSchema('other_schema'))
event.listen(
Base.metadata,
'before_create',
DDL('CREATE SCHEMA IF NOT EXISTS other_schema')
)


class Basic(Base):
Expand Down
39 changes: 30 additions & 9 deletions tests/test_column_loading.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,45 @@
from unittest import TestCase

from packaging.version import Version
import sqlalchemy as sa
from sqlalchemy.types import NullType, VARCHAR

from sqlalchemy_redshift.dialect import RedshiftDialect

sa_version = Version(sa.__version__)


class TestColumnReflection(TestCase):
def test_varchar_as_nulltype(self):
"""
Varchar columns with no length should be considered NullType columns
"""
dialect = RedshiftDialect()
column_info = dialect._get_column_info(
'Null Column',
'character varying', None, False, {}, {}, 'default', 'test column'

null_info = dialect._get_column_info(
name='Null Column',
format_type='character varying',
default=None,
notnull=False,
domains={},
enums=[],
schema='default',
encode='',
comment='test column',
identity=None
)
assert isinstance(column_info['type'], NullType)
column_info_1 = dialect._get_column_info(
'character column',
'character varying(30)', None, False, {}, {}, 'default',
comment='test column'
assert isinstance(null_info['type'], NullType)

varchar_info = dialect._get_column_info(
name='character column',
format_type='character varying(30)',
default=None,
notnull=False,
domains={},
enums=[],
schema='default',
encode='',
comment='test column',
identity=None
)
assert isinstance(column_info_1['type'], VARCHAR)
assert isinstance(varchar_info['type'], VARCHAR)
15 changes: 11 additions & 4 deletions tests/test_delete_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
"""

import sqlalchemy as sa
from packaging.version import Version

from rs_sqla_test_utils.utils import clean, compile_query

sa_version = Version(sa.__version__)


meta = sa.MetaData()

Expand Down Expand Up @@ -116,10 +119,14 @@ def test_delete_stmt_simplewhereclause2():
del_stmt = sa.delete(customers).where(
customers.c.email.endswith('test.com')
)
expected = """
DELETE FROM customers
WHERE customers.email
LIKE '%%' || 'test.com'"""
if sa_version >= Version('1.4.0'):
expected = """
DELETE FROM customers
WHERE (customers.email LIKE '%%' || 'test.com')"""
else:
expected = """
DELETE FROM customers
WHERE customers.email LIKE '%%' || 'test.com'"""
assert clean(compile_query(del_stmt)) == clean(expected)


Expand Down
44 changes: 26 additions & 18 deletions tests/test_reflection_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ def test_view_reflection(redshift_engine):
view_query = "SELECT my_table.col1, my_table.col2 FROM my_table"
view_ddl = "CREATE VIEW my_view AS %s" % view_query
conn = redshift_engine.connect()
conn.execute(table_ddl)
conn.execute(view_ddl)
insp = inspect(redshift_engine)
view_definition = insp.get_view_definition('my_view')
assert(clean(compile_query(view_definition)) == clean(view_query))
view = Table('my_view', MetaData(),
autoload=True, autoload_with=redshift_engine)
assert(len(view.columns) == 2)
try:
conn.execute(table_ddl)
conn.execute(view_ddl)
insp = inspect(redshift_engine)
view_definition = insp.get_view_definition('my_view')
assert(clean(compile_query(view_definition)) == clean(view_query))
view = Table('my_view', MetaData(),
autoload=True, autoload_with=redshift_engine)
assert(len(view.columns) == 2)
finally:
conn.execute('DROP TABLE IF EXISTS my_table CASCADE')
conn.execute('DROP VIEW IF EXISTS my_view CASCADE')


def test_late_binding_view_reflection(redshift_engine):
Expand All @@ -30,13 +34,17 @@ def test_late_binding_view_reflection(redshift_engine):
view_ddl = ("CREATE VIEW my_late_view AS "
"%s WITH NO SCHEMA BINDING" % view_query)
conn = redshift_engine.connect()
conn.execute(table_ddl)
conn.execute(view_ddl)
insp = inspect(redshift_engine)
view_definition = insp.get_view_definition('my_late_view')

# For some reason, Redshift returns the entire DDL for late binding views.
assert(clean(compile_query(view_definition)) == clean(view_ddl))
view = Table('my_late_view', MetaData(),
autoload=True, autoload_with=redshift_engine)
assert(len(view.columns) == 2)
try:
conn.execute(table_ddl)
conn.execute(view_ddl)
insp = inspect(redshift_engine)
view_definition = insp.get_view_definition('my_late_view')

# Redshift returns the entire DDL for late binding views.
assert(clean(compile_query(view_definition)) == clean(view_ddl))
view = Table('my_late_view', MetaData(),
autoload=True, autoload_with=redshift_engine)
assert(len(view.columns) == 2)
finally:
conn.execute('DROP TABLE IF EXISTS my_table CASCADE')
conn.execute('DROP VIEW IF EXISTS my_late_view CASCADE')
17 changes: 11 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
[tox]
envlist = py{27,34,35,36,37}, lint, docs
envlist =
py35-sa13,
py{27,36,37,38,39}-sa{13,14},
lint,
docs

[testenv]
passenv = PGPASSWORD
commands = py.test {posargs}
commands = pytest {posargs}
deps =
requests==2.7.0
psycopg2==2.7.3.2
sqlalchemy==1.3.0
pytest==3.10.1
alembic==1.4.2
packaging==20.4
psycopg2==2.8.6
pytest==3.10.1
requests==2.7.0
sa13: sqlalchemy==1.3.24
sa14: sqlalchemy==1.4.15

[testenv:lint]
deps =
Expand Down

0 comments on commit 4ecb81d

Please sign in to comment.