Skip to content

Commit

Permalink
SQLModel Code generation fixes (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
sheinbergon authored Jan 23, 2025
1 parent f0b5af5 commit a7e38e6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ dynamic = ["version"]

[project.optional-dependencies]
test = [
"sqlacodegen[sqlmodel]",
"pytest >= 7.4",
"coverage >= 7",
"psycopg2-binary",
"mysql-connector-python",
]
sqlmodel = ["sqlmodel >= 0.0.12"]
sqlmodel = ["sqlmodel >= 0.0.22"]
citext = ["sqlalchemy-citext >= 1.7.0"]
geoalchemy2 = ["geoalchemy2 >= 0.11.1"]
pgvector = ["pgvector >= 0.2.4"]
Expand Down
20 changes: 17 additions & 3 deletions src/sqlacodegen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def main() -> None:
parser.add_argument(
"--tables", help="tables to process (comma-delimited, default: all)"
)
parser.add_argument("--noviews", action="store_true", help="ignore views")
parser.add_argument(
"--noviews",
action="store_true",
help="ignore views (always true for sqlmodels generator)",
)
parser.add_argument("--outfile", help="file to write output to (default: stdout)")
args = parser.parse_args()

Expand Down Expand Up @@ -81,13 +85,23 @@ def main() -> None:
tables = args.tables.split(",") if args.tables else None
schemas = args.schemas.split(",") if args.schemas else [None]
options = set(args.options.split(",")) if args.options else set()
for schema in schemas:
metadata.reflect(engine, schema, not args.noviews, tables)

# Instantiate the generator
generator_class = generators[args.generator].load()
generator = generator_class(metadata, engine, options)

if not generator.views_supported:
name = generator_class.__name__
print(
f"VIEW models will not be generated when using the '{name}' generator",
file=sys.stderr,
)

for schema in schemas:
metadata.reflect(
engine, schema, (generator.views_supported and not args.noviews), tables
)

# Open the target file (if given)
with ExitStack() as stack:
outfile: TextIO
Expand Down
20 changes: 20 additions & 0 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def __init__(
if invalid_options:
raise ValueError("Unrecognized options: " + ", ".join(invalid_options))

@property
@abstractmethod
def views_supported(self) -> bool:
pass

@abstractmethod
def generate(self) -> str:
"""
Expand Down Expand Up @@ -134,6 +139,10 @@ def __init__(
self.imports: dict[str, set[str]] = defaultdict(set)
self.module_imports: set[str] = set()

@property
def views_supported(self) -> bool:
return True

def generate_base(self) -> None:
self.base = Base(
literal_imports=[LiteralImport("sqlalchemy", "MetaData")],
Expand Down Expand Up @@ -482,6 +491,9 @@ def render_column(
if comment:
kwargs["comment"] = repr(comment)

return self.render_column_callable(is_table, *args, **kwargs)

def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
if is_table:
self.add_import(Column)
return render_callable("Column", *args, kwargs=kwargs)
Expand Down Expand Up @@ -1358,6 +1370,14 @@ def __init__(
base_class_name=base_class_name,
)

@property
def views_supported(self) -> bool:
return False

def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
self.add_import(Column)
return render_callable("Column", *args, kwargs=kwargs)

def generate_base(self) -> None:
self.base = Base(
literal_imports=[],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
class Foo(SQLModel, table=True):
id: Optional[int] = Field(default=None, sa_column=Column('id', Integer, \
primary_key=True))
name: str = Field(sa_column=Column('name', Text, nullable=False))
name: str = Field(sa_column=Column('name', Text))
"""
)

Expand Down

0 comments on commit a7e38e6

Please sign in to comment.