Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Column and metadata in SQLModel Generator #306

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,84 @@ def __init__(
base_class_name=base_class_name,
)

def generate_models(self) -> list[Model]:
models_by_table_name: dict[str, Model] = {}

# Pick association tables from the metadata into their own set, don't process
# them normally
links: defaultdict[str, list[Model]] = defaultdict(lambda: [])
for table in self.metadata.sorted_tables:
qualified_name = qualified_table_name(table)

# Link tables have exactly two foreign key constraints and all columns are
# involved in them
fk_constraints = sorted(
table.foreign_key_constraints, key=get_constraint_sort_key
)
if len(fk_constraints) == 2 and all(
col.foreign_keys for col in table.columns
):
model = models_by_table_name[qualified_name] = Model(table)
tablename = fk_constraints[0].elements[0].column.table.name
links[tablename].append(model)
continue

# Only form model classes for tables that have a primary key and are not
# association tables
if not table.primary_key:
models_by_table_name[qualified_name] = Model(table)
else:
model = ModelClass(table)
models_by_table_name[qualified_name] = model

# Fill in the columns
for column in table.c:
column_attr = ColumnAttribute(model, column)
model.columns.append(column_attr)

# Add relationships
for model in models_by_table_name.values():
if isinstance(model, ModelClass):
self.generate_relationships(
model, models_by_table_name, links[model.table.name]
)

# Nest inherited classes in their superclasses to ensure proper ordering
if "nojoined" not in self.options:
for model in list(models_by_table_name.values()):
if not isinstance(model, ModelClass):
continue

pk_column_names = {col.name for col in model.table.primary_key.columns}
for constraint in model.table.foreign_key_constraints:
if set(get_column_names(constraint)) == pk_column_names:
target = models_by_table_name[
qualified_table_name(constraint.elements[0].column.table)
]
if isinstance(target, ModelClass):
model.parent_class = target
target.children.append(model)

# Change base if we have both tables and model classes
if any(
not isinstance(model, ModelClass) for model in models_by_table_name.values()
):
TablesGenerator.generate_base(self)

# Collect the imports
self.collect_imports(models_by_table_name.values())

# Rename models and their attributes that conflict with imports or other
# attributes
global_names = {
name for namespace in self.imports.values() for name in namespace
}
for model in models_by_table_name.values():
self.generate_model_name(model, global_names)
global_names.add(model.name)

return list(models_by_table_name.values())

def generate_base(self) -> None:
self.base = Base(
literal_imports=[],
Expand All @@ -1538,7 +1616,6 @@ def generate_base(self) -> None:
def collect_imports(self, models: Iterable[Model]) -> None:
super(DeclarativeGenerator, self).collect_imports(models)
if any(isinstance(model, ModelClass) for model in models):
self.remove_literal_import("sqlalchemy", "MetaData")
self.add_literal_import("sqlmodel", "SQLModel")
self.add_literal_import("sqlmodel", "Field")

Expand Down Expand Up @@ -1570,7 +1647,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
self.add_import(python_type)

def render_module_variables(self, models: list[Model]) -> str:
declarations: list[str] = []
declarations: list[str] = self.base.declarations
if any(not isinstance(model, ModelClass) for model in models):
if self.base.table_metadata_declaration is not None:
declarations.append(self.base.table_metadata_declaration)
Expand Down Expand Up @@ -1616,7 +1693,7 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
kwargs["default"] = None
python_type_name = f"Optional[{python_type_name}]"

rendered_column = self.render_column(column, True)
rendered_column = self.render_column(column, True, is_table=True)
kwargs["sa_column"] = f"{rendered_column}"
rendered_field = render_callable("Field", kwargs=kwargs)
return f"{column_attr.name}: {python_type_name} = {rendered_field}"
Expand Down