25
25
Computed ,
26
26
Constraint ,
27
27
DefaultClause ,
28
+ Dialect ,
28
29
Enum ,
29
30
Float ,
30
31
ForeignKey ,
39
40
UniqueConstraint ,
40
41
)
41
42
from sqlalchemy .dialects .postgresql import JSONB
42
- from sqlalchemy .engine import Connection , Engine
43
43
from sqlalchemy .exc import CompileError
44
44
from sqlalchemy .sql .elements import TextClause
45
45
@@ -94,11 +94,9 @@ class Base:
94
94
class CodeGenerator (metaclass = ABCMeta ):
95
95
valid_options : ClassVar [set [str ]] = set ()
96
96
97
- def __init__ (
98
- self , metadata : MetaData , bind : Connection | Engine , options : Sequence [str ]
99
- ):
97
+ def __init__ (self , metadata : MetaData , dialect : Dialect , options : Sequence [str ]):
100
98
self .metadata : MetaData = metadata
101
- self .bind : Connection | Engine = bind
99
+ self .dialect : Dialect = dialect
102
100
self .options : set [str ] = set (options )
103
101
104
102
# Validate options
@@ -124,12 +122,12 @@ class TablesGenerator(CodeGenerator):
124
122
def __init__ (
125
123
self ,
126
124
metadata : MetaData ,
127
- bind : Connection | Engine ,
125
+ dialect : Dialect ,
128
126
options : Sequence [str ],
129
127
* ,
130
128
indentation : str = " " ,
131
129
):
132
- super ().__init__ (metadata , bind , options )
130
+ super ().__init__ (metadata , dialect , options )
133
131
self .indentation : str = indentation
134
132
self .imports : dict [str , set [str ]] = defaultdict (set )
135
133
self .module_imports : set [str ] = set ()
@@ -562,7 +560,7 @@ def add_fk_options(*opts: Any) -> None:
562
560
]
563
561
add_fk_options (local_columns , remote_columns )
564
562
elif isinstance (constraint , CheckConstraint ):
565
- args .append (repr (get_compiled_expression (constraint .sqltext , self .bind )))
563
+ args .append (repr (get_compiled_expression (constraint .sqltext , self .dialect )))
566
564
elif isinstance (constraint , (UniqueConstraint , PrimaryKeyConstraint )):
567
565
args .extend (repr (col .name ) for col in constraint .columns )
568
566
else :
@@ -608,7 +606,7 @@ def fix_column_types(self, table: Table) -> None:
608
606
# Detect check constraints for boolean and enum columns
609
607
for constraint in table .constraints .copy ():
610
608
if isinstance (constraint , CheckConstraint ):
611
- sqltext = get_compiled_expression (constraint .sqltext , self .bind )
609
+ sqltext = get_compiled_expression (constraint .sqltext , self .dialect )
612
610
613
611
# Turn any integer-like column with a CheckConstraint like
614
612
# "column IN (0, 1)" into a Boolean
@@ -646,7 +644,7 @@ def fix_column_types(self, table: Table) -> None:
646
644
pass
647
645
648
646
# PostgreSQL specific fix: detect sequences from server_default
649
- if column .server_default and self .bind . dialect .name == "postgresql" :
647
+ if column .server_default and self .dialect .name == "postgresql" :
650
648
if isinstance (column .server_default , DefaultClause ) and isinstance (
651
649
column .server_default .arg , TextClause
652
650
):
@@ -661,7 +659,7 @@ def fix_column_types(self, table: Table) -> None:
661
659
column .server_default = None
662
660
663
661
def get_adapted_type (self , coltype : Any ) -> Any :
664
- compiled_type = coltype .compile (self .bind . engine . dialect )
662
+ compiled_type = coltype .compile (self .dialect )
665
663
for supercls in coltype .__class__ .__mro__ :
666
664
if not supercls .__name__ .startswith ("_" ) and hasattr (
667
665
supercls , "__visit_name__"
@@ -687,7 +685,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
687
685
try :
688
686
# If the adapted column type does not render the same as the
689
687
# original, don't substitute it
690
- if new_coltype .compile (self .bind . engine . dialect ) != compiled_type :
688
+ if new_coltype .compile (self .dialect ) != compiled_type :
691
689
# Make an exception to the rule for Float and arrays of Float,
692
690
# since at least on PostgreSQL, Float can accurately represent
693
691
# both REAL and DOUBLE_PRECISION
@@ -718,13 +716,13 @@ class DeclarativeGenerator(TablesGenerator):
718
716
def __init__ (
719
717
self ,
720
718
metadata : MetaData ,
721
- bind : Connection | Engine ,
719
+ dialect : Dialect ,
722
720
options : Sequence [str ],
723
721
* ,
724
722
indentation : str = " " ,
725
723
base_class_name : str = "Base" ,
726
724
):
727
- super ().__init__ (metadata , bind , options , indentation = indentation )
725
+ super ().__init__ (metadata , dialect , options , indentation = indentation )
728
726
self .base_class_name : str = base_class_name
729
727
self .inflect_engine = inflect .engine ()
730
728
@@ -1305,7 +1303,7 @@ class DataclassGenerator(DeclarativeGenerator):
1305
1303
def __init__ (
1306
1304
self ,
1307
1305
metadata : MetaData ,
1308
- bind : Connection | Engine ,
1306
+ dialect : Dialect ,
1309
1307
options : Sequence [str ],
1310
1308
* ,
1311
1309
indentation : str = " " ,
@@ -1315,7 +1313,7 @@ def __init__(
1315
1313
):
1316
1314
super ().__init__ (
1317
1315
metadata ,
1318
- bind ,
1316
+ dialect ,
1319
1317
options ,
1320
1318
indentation = indentation ,
1321
1319
base_class_name = base_class_name ,
@@ -1344,15 +1342,15 @@ class SQLModelGenerator(DeclarativeGenerator):
1344
1342
def __init__ (
1345
1343
self ,
1346
1344
metadata : MetaData ,
1347
- bind : Connection | Engine ,
1345
+ dialect : Dialect ,
1348
1346
options : Sequence [str ],
1349
1347
* ,
1350
1348
indentation : str = " " ,
1351
1349
base_class_name : str = "SQLModel" ,
1352
1350
):
1353
1351
super ().__init__ (
1354
1352
metadata ,
1355
- bind ,
1353
+ dialect ,
1356
1354
options ,
1357
1355
indentation = indentation ,
1358
1356
base_class_name = base_class_name ,
0 commit comments