Skip to content

Commit a6faddd

Browse files
committed
refactor(generator): Enable use of generators without engine
1 parent 2a60532 commit a6faddd

9 files changed

+57
-60
lines changed

pyproject.toml

-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ dynamic = ["version"]
4343
test = [
4444
"pytest >= 7.4",
4545
"coverage >= 7",
46-
"psycopg2-binary",
47-
"mysql-connector-python",
4846
]
4947
sqlmodel = ["sqlmodel >= 0.0.12"]
5048
citext = ["sqlalchemy-citext >= 1.7.0"]

src/sqlacodegen/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def main() -> None:
8686

8787
# Instantiate the generator
8888
generator_class = generators[args.generator].load()
89-
generator = generator_class(metadata, engine, options)
89+
generator = generator_class(metadata, engine.dialect, options)
9090

9191
# Open the target file (if given)
9292
with ExitStack() as stack:

src/sqlacodegen/generators.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Computed,
2626
Constraint,
2727
DefaultClause,
28+
Dialect,
2829
Enum,
2930
Float,
3031
ForeignKey,
@@ -39,7 +40,6 @@
3940
UniqueConstraint,
4041
)
4142
from sqlalchemy.dialects.postgresql import JSONB
42-
from sqlalchemy.engine import Connection, Engine
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
4545

@@ -94,11 +94,9 @@ class Base:
9494
class CodeGenerator(metaclass=ABCMeta):
9595
valid_options: ClassVar[set[str]] = set()
9696

97-
def __init__(
98-
self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
99-
):
97+
def __init__(self, metadata: MetaData, dialect: Dialect, options: Sequence[str]):
10098
self.metadata: MetaData = metadata
101-
self.bind: Connection | Engine = bind
99+
self.dialect: Dialect = dialect
102100
self.options: set[str] = set(options)
103101

104102
# Validate options
@@ -124,12 +122,12 @@ class TablesGenerator(CodeGenerator):
124122
def __init__(
125123
self,
126124
metadata: MetaData,
127-
bind: Connection | Engine,
125+
dialect: Dialect,
128126
options: Sequence[str],
129127
*,
130128
indentation: str = " ",
131129
):
132-
super().__init__(metadata, bind, options)
130+
super().__init__(metadata, dialect, options)
133131
self.indentation: str = indentation
134132
self.imports: dict[str, set[str]] = defaultdict(set)
135133
self.module_imports: set[str] = set()
@@ -562,7 +560,7 @@ def add_fk_options(*opts: Any) -> None:
562560
]
563561
add_fk_options(local_columns, remote_columns)
564562
elif isinstance(constraint, CheckConstraint):
565-
args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
563+
args.append(repr(get_compiled_expression(constraint.sqltext, self.dialect)))
566564
elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
567565
args.extend(repr(col.name) for col in constraint.columns)
568566
else:
@@ -608,7 +606,7 @@ def fix_column_types(self, table: Table) -> None:
608606
# Detect check constraints for boolean and enum columns
609607
for constraint in table.constraints.copy():
610608
if isinstance(constraint, CheckConstraint):
611-
sqltext = get_compiled_expression(constraint.sqltext, self.bind)
609+
sqltext = get_compiled_expression(constraint.sqltext, self.dialect)
612610

613611
# Turn any integer-like column with a CheckConstraint like
614612
# "column IN (0, 1)" into a Boolean
@@ -646,7 +644,7 @@ def fix_column_types(self, table: Table) -> None:
646644
pass
647645

648646
# PostgreSQL specific fix: detect sequences from server_default
649-
if column.server_default and self.bind.dialect.name == "postgresql":
647+
if column.server_default and self.dialect.name == "postgresql":
650648
if isinstance(column.server_default, DefaultClause) and isinstance(
651649
column.server_default.arg, TextClause
652650
):
@@ -661,7 +659,7 @@ def fix_column_types(self, table: Table) -> None:
661659
column.server_default = None
662660

663661
def get_adapted_type(self, coltype: Any) -> Any:
664-
compiled_type = coltype.compile(self.bind.engine.dialect)
662+
compiled_type = coltype.compile(self.dialect)
665663
for supercls in coltype.__class__.__mro__:
666664
if not supercls.__name__.startswith("_") and hasattr(
667665
supercls, "__visit_name__"
@@ -687,7 +685,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
687685
try:
688686
# If the adapted column type does not render the same as the
689687
# original, don't substitute it
690-
if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
688+
if new_coltype.compile(self.dialect) != compiled_type:
691689
# Make an exception to the rule for Float and arrays of Float,
692690
# since at least on PostgreSQL, Float can accurately represent
693691
# both REAL and DOUBLE_PRECISION
@@ -718,13 +716,13 @@ class DeclarativeGenerator(TablesGenerator):
718716
def __init__(
719717
self,
720718
metadata: MetaData,
721-
bind: Connection | Engine,
719+
dialect: Dialect,
722720
options: Sequence[str],
723721
*,
724722
indentation: str = " ",
725723
base_class_name: str = "Base",
726724
):
727-
super().__init__(metadata, bind, options, indentation=indentation)
725+
super().__init__(metadata, dialect, options, indentation=indentation)
728726
self.base_class_name: str = base_class_name
729727
self.inflect_engine = inflect.engine()
730728

@@ -1305,7 +1303,7 @@ class DataclassGenerator(DeclarativeGenerator):
13051303
def __init__(
13061304
self,
13071305
metadata: MetaData,
1308-
bind: Connection | Engine,
1306+
dialect: Dialect,
13091307
options: Sequence[str],
13101308
*,
13111309
indentation: str = " ",
@@ -1315,7 +1313,7 @@ def __init__(
13151313
):
13161314
super().__init__(
13171315
metadata,
1318-
bind,
1316+
dialect,
13191317
options,
13201318
indentation=indentation,
13211319
base_class_name=base_class_name,
@@ -1344,15 +1342,15 @@ class SQLModelGenerator(DeclarativeGenerator):
13441342
def __init__(
13451343
self,
13461344
metadata: MetaData,
1347-
bind: Connection | Engine,
1345+
dialect: Dialect,
13481346
options: Sequence[str],
13491347
*,
13501348
indentation: str = " ",
13511349
base_class_name: str = "SQLModel",
13521350
):
13531351
super().__init__(
13541352
metadata,
1355-
bind,
1353+
dialect,
13561354
options,
13571355
indentation=indentation,
13581356
base_class_name=base_class_name,

src/sqlacodegen/utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from collections.abc import Mapping
55
from typing import Any
66

7-
from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
8-
from sqlalchemy.engine import Connection, Engine
7+
from sqlalchemy import Dialect, PrimaryKeyConstraint, UniqueConstraint
98
from sqlalchemy.sql import ClauseElement
109
from sqlalchemy.sql.elements import TextClause
1110
from sqlalchemy.sql.schema import (
@@ -34,9 +33,11 @@ def get_constraint_sort_key(constraint: Constraint) -> str:
3433
return str(constraint)
3534

3635

37-
def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
36+
def get_compiled_expression(statement: ClauseElement, dialect: Dialect) -> str:
3837
"""Return the statement in a form where any placeholders have been filled in."""
39-
return str(statement.compile(bind, compile_kwargs={"literal_binds": True}))
38+
return str(
39+
statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
40+
)
4041

4142

4243
def get_common_fk_constraints(

tests/conftest.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22

33
import pytest
44
from pytest import FixtureRequest
5-
from sqlalchemy.engine import Engine, create_engine
5+
from sqlalchemy import Dialect
6+
from sqlalchemy.dialects import mysql, postgresql, sqlite
67
from sqlalchemy.orm import clear_mappers, configure_mappers
78
from sqlalchemy.schema import MetaData
89

910

1011
@pytest.fixture
11-
def engine(request: FixtureRequest) -> Engine:
12-
dialect = getattr(request, "param", None)
13-
if dialect == "postgresql":
14-
return create_engine("postgresql:///testdb")
15-
elif dialect == "mysql":
16-
return create_engine("mysql+mysqlconnector://testdb")
12+
def dialect(request: FixtureRequest) -> Dialect:
13+
dialect_name = getattr(request, "param", None)
14+
if dialect_name == "postgresql":
15+
return postgresql.dialect()
16+
elif dialect_name == "mysql":
17+
return mysql.mysqlconnector.dialect()
1718
else:
18-
return create_engine("sqlite:///:memory:")
19+
return sqlite.dialect()
1920

2021

2122
@pytest.fixture

tests/test_generator_dataclass.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5+
from sqlalchemy import Dialect
56
from sqlalchemy.dialects.postgresql import UUID
6-
from sqlalchemy.engine import Engine
77
from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
88
from sqlalchemy.sql.expression import text
99
from sqlalchemy.types import INTEGER, VARCHAR
@@ -15,10 +15,10 @@
1515

1616
@pytest.fixture
1717
def generator(
18-
request: FixtureRequest, metadata: MetaData, engine: Engine
18+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
1919
) -> CodeGenerator:
2020
options = getattr(request, "param", [])
21-
return DataclassGenerator(metadata, engine, options)
21+
return DataclassGenerator(metadata, dialect, options)
2222

2323

2424
def test_basic_class(generator: CodeGenerator) -> None:

tests/test_generator_declarative.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy import PrimaryKeyConstraint
6-
from sqlalchemy.engine import Engine
5+
from sqlalchemy import Dialect, PrimaryKeyConstraint
76
from sqlalchemy.schema import (
87
CheckConstraint,
98
Column,
@@ -24,10 +23,10 @@
2423

2524
@pytest.fixture
2625
def generator(
27-
request: FixtureRequest, metadata: MetaData, engine: Engine
26+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
2827
) -> CodeGenerator:
2928
options = getattr(request, "param", [])
30-
return DeclarativeGenerator(metadata, engine, options)
29+
return DeclarativeGenerator(metadata, dialect, options)
3130

3231

3332
def test_indexes(generator: CodeGenerator) -> None:

tests/test_generator_sqlmodel.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy.engine import Engine
5+
from sqlalchemy import Dialect
66
from sqlalchemy.schema import (
77
CheckConstraint,
88
Column,
@@ -21,10 +21,10 @@
2121

2222
@pytest.fixture
2323
def generator(
24-
request: FixtureRequest, metadata: MetaData, engine: Engine
24+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
2525
) -> CodeGenerator:
2626
options = getattr(request, "param", [])
27-
return SQLModelGenerator(metadata, engine, options)
27+
return SQLModelGenerator(metadata, dialect, options)
2828

2929

3030
def test_indexes(generator: CodeGenerator) -> None:

0 commit comments

Comments
 (0)