Skip to content

Commit bc68197

Browse files
committed
refactor(generator): Enable use of generators without engine
1 parent 2a60532 commit bc68197

9 files changed

+55
-58
lines changed

pyproject.toml

-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ dynamic = ["version"]
4343
test = [
4444
"pytest >= 7.4",
4545
"coverage >= 7",
46-
"psycopg2-binary",
47-
"mysql-connector-python",
4846
]
4947
sqlmodel = ["sqlmodel >= 0.0.12"]
5048
citext = ["sqlalchemy-citext >= 1.7.0"]

src/sqlacodegen/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def main() -> None:
8686

8787
# Instantiate the generator
8888
generator_class = generators[args.generator].load()
89-
generator = generator_class(metadata, engine, options)
89+
generator = generator_class(metadata, engine.dialect, options)
9090

9191
# Open the target file (if given)
9292
with ExitStack() as stack:

src/sqlacodegen/generators.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Computed,
2626
Constraint,
2727
DefaultClause,
28+
Dialect,
2829
Enum,
2930
Float,
3031
ForeignKey,
@@ -39,7 +40,6 @@
3940
UniqueConstraint,
4041
)
4142
from sqlalchemy.dialects.postgresql import JSONB
42-
from sqlalchemy.engine import Connection, Engine
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
4545

@@ -95,10 +95,10 @@ class CodeGenerator(metaclass=ABCMeta):
9595
valid_options: ClassVar[set[str]] = set()
9696

9797
def __init__(
98-
self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
98+
self, metadata: MetaData, dialect: Dialect, options: Sequence[str]
9999
):
100100
self.metadata: MetaData = metadata
101-
self.bind: Connection | Engine = bind
101+
self.dialect: Dialect = dialect
102102
self.options: set[str] = set(options)
103103

104104
# Validate options
@@ -124,12 +124,12 @@ class TablesGenerator(CodeGenerator):
124124
def __init__(
125125
self,
126126
metadata: MetaData,
127-
bind: Connection | Engine,
127+
dialect: Dialect,
128128
options: Sequence[str],
129129
*,
130130
indentation: str = " ",
131131
):
132-
super().__init__(metadata, bind, options)
132+
super().__init__(metadata, dialect, options)
133133
self.indentation: str = indentation
134134
self.imports: dict[str, set[str]] = defaultdict(set)
135135
self.module_imports: set[str] = set()
@@ -562,7 +562,7 @@ def add_fk_options(*opts: Any) -> None:
562562
]
563563
add_fk_options(local_columns, remote_columns)
564564
elif isinstance(constraint, CheckConstraint):
565-
args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
565+
args.append(repr(get_compiled_expression(constraint.sqltext, self.dialect)))
566566
elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
567567
args.extend(repr(col.name) for col in constraint.columns)
568568
else:
@@ -608,7 +608,7 @@ def fix_column_types(self, table: Table) -> None:
608608
# Detect check constraints for boolean and enum columns
609609
for constraint in table.constraints.copy():
610610
if isinstance(constraint, CheckConstraint):
611-
sqltext = get_compiled_expression(constraint.sqltext, self.bind)
611+
sqltext = get_compiled_expression(constraint.sqltext, self.dialect)
612612

613613
# Turn any integer-like column with a CheckConstraint like
614614
# "column IN (0, 1)" into a Boolean
@@ -646,7 +646,7 @@ def fix_column_types(self, table: Table) -> None:
646646
pass
647647

648648
# PostgreSQL specific fix: detect sequences from server_default
649-
if column.server_default and self.bind.dialect.name == "postgresql":
649+
if column.server_default and self.dialect.name == "postgresql":
650650
if isinstance(column.server_default, DefaultClause) and isinstance(
651651
column.server_default.arg, TextClause
652652
):
@@ -661,7 +661,7 @@ def fix_column_types(self, table: Table) -> None:
661661
column.server_default = None
662662

663663
def get_adapted_type(self, coltype: Any) -> Any:
664-
compiled_type = coltype.compile(self.bind.engine.dialect)
664+
compiled_type = coltype.compile(self.dialect)
665665
for supercls in coltype.__class__.__mro__:
666666
if not supercls.__name__.startswith("_") and hasattr(
667667
supercls, "__visit_name__"
@@ -687,7 +687,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
687687
try:
688688
# If the adapted column type does not render the same as the
689689
# original, don't substitute it
690-
if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
690+
if new_coltype.compile(self.dialect) != compiled_type:
691691
# Make an exception to the rule for Float and arrays of Float,
692692
# since at least on PostgreSQL, Float can accurately represent
693693
# both REAL and DOUBLE_PRECISION
@@ -718,13 +718,13 @@ class DeclarativeGenerator(TablesGenerator):
718718
def __init__(
719719
self,
720720
metadata: MetaData,
721-
bind: Connection | Engine,
721+
dialect: Dialect,
722722
options: Sequence[str],
723723
*,
724724
indentation: str = " ",
725725
base_class_name: str = "Base",
726726
):
727-
super().__init__(metadata, bind, options, indentation=indentation)
727+
super().__init__(metadata, dialect, options, indentation=indentation)
728728
self.base_class_name: str = base_class_name
729729
self.inflect_engine = inflect.engine()
730730

@@ -1305,7 +1305,7 @@ class DataclassGenerator(DeclarativeGenerator):
13051305
def __init__(
13061306
self,
13071307
metadata: MetaData,
1308-
bind: Connection | Engine,
1308+
dialect: Dialect,
13091309
options: Sequence[str],
13101310
*,
13111311
indentation: str = " ",
@@ -1315,7 +1315,7 @@ def __init__(
13151315
):
13161316
super().__init__(
13171317
metadata,
1318-
bind,
1318+
dialect,
13191319
options,
13201320
indentation=indentation,
13211321
base_class_name=base_class_name,
@@ -1344,15 +1344,15 @@ class SQLModelGenerator(DeclarativeGenerator):
13441344
def __init__(
13451345
self,
13461346
metadata: MetaData,
1347-
bind: Connection | Engine,
1347+
dialect: Dialect,
13481348
options: Sequence[str],
13491349
*,
13501350
indentation: str = " ",
13511351
base_class_name: str = "SQLModel",
13521352
):
13531353
super().__init__(
13541354
metadata,
1355-
bind,
1355+
dialect,
13561356
options,
13571357
indentation=indentation,
13581358
base_class_name=base_class_name,

src/sqlacodegen/utils.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from collections.abc import Mapping
55
from typing import Any
66

7-
from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
8-
from sqlalchemy.engine import Connection, Engine
7+
from sqlalchemy import Dialect, PrimaryKeyConstraint, UniqueConstraint
98
from sqlalchemy.sql import ClauseElement
109
from sqlalchemy.sql.elements import TextClause
1110
from sqlalchemy.sql.schema import (
@@ -34,9 +33,9 @@ def get_constraint_sort_key(constraint: Constraint) -> str:
3433
return str(constraint)
3534

3635

37-
def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
36+
def get_compiled_expression(statement: ClauseElement, dialect: Dialect) -> str:
3837
"""Return the statement in a form where any placeholders have been filled in."""
39-
return str(statement.compile(bind, compile_kwargs={"literal_binds": True}))
38+
return str(statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
4039

4140

4241
def get_common_fk_constraints(

tests/conftest.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22

33
import pytest
44
from pytest import FixtureRequest
5-
from sqlalchemy.engine import Engine, create_engine
5+
from sqlalchemy import Dialect
6+
from sqlalchemy.dialects import mysql, postgresql, sqlite
67
from sqlalchemy.orm import clear_mappers, configure_mappers
78
from sqlalchemy.schema import MetaData
89

910

1011
@pytest.fixture
11-
def engine(request: FixtureRequest) -> Engine:
12-
dialect = getattr(request, "param", None)
13-
if dialect == "postgresql":
14-
return create_engine("postgresql:///testdb")
15-
elif dialect == "mysql":
16-
return create_engine("mysql+mysqlconnector://testdb")
12+
def dialect(request: FixtureRequest) -> Dialect:
13+
dialect_name = getattr(request, "param", None)
14+
if dialect_name == "postgresql":
15+
return postgresql.dialect()
16+
elif dialect_name == "mysql":
17+
return mysql.mysqlconnector.dialect()
1718
else:
18-
return create_engine("sqlite:///:memory:")
19+
return sqlite.dialect()
1920

2021

2122
@pytest.fixture

tests/test_generator_dataclass.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5+
from sqlalchemy import Dialect
56
from sqlalchemy.dialects.postgresql import UUID
6-
from sqlalchemy.engine import Engine
77
from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
88
from sqlalchemy.sql.expression import text
99
from sqlalchemy.types import INTEGER, VARCHAR
@@ -15,10 +15,10 @@
1515

1616
@pytest.fixture
1717
def generator(
18-
request: FixtureRequest, metadata: MetaData, engine: Engine
18+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
1919
) -> CodeGenerator:
2020
options = getattr(request, "param", [])
21-
return DataclassGenerator(metadata, engine, options)
21+
return DataclassGenerator(metadata, dialect, options)
2222

2323

2424
def test_basic_class(generator: CodeGenerator) -> None:

tests/test_generator_declarative.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy import PrimaryKeyConstraint
6-
from sqlalchemy.engine import Engine
5+
from sqlalchemy import Dialect, PrimaryKeyConstraint
76
from sqlalchemy.schema import (
87
CheckConstraint,
98
Column,
@@ -24,10 +23,10 @@
2423

2524
@pytest.fixture
2625
def generator(
27-
request: FixtureRequest, metadata: MetaData, engine: Engine
26+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
2827
) -> CodeGenerator:
2928
options = getattr(request, "param", [])
30-
return DeclarativeGenerator(metadata, engine, options)
29+
return DeclarativeGenerator(metadata, dialect, options)
3130

3231

3332
def test_indexes(generator: CodeGenerator) -> None:

tests/test_generator_sqlmodel.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy.engine import Engine
5+
from sqlalchemy import Dialect
66
from sqlalchemy.schema import (
77
CheckConstraint,
88
Column,
@@ -21,10 +21,10 @@
2121

2222
@pytest.fixture
2323
def generator(
24-
request: FixtureRequest, metadata: MetaData, engine: Engine
24+
request: FixtureRequest, metadata: MetaData, dialect: Dialect
2525
) -> CodeGenerator:
2626
options = getattr(request, "param", [])
27-
return SQLModelGenerator(metadata, engine, options)
27+
return SQLModelGenerator(metadata, dialect, options)
2828

2929

3030
def test_indexes(generator: CodeGenerator) -> None:

0 commit comments

Comments
 (0)