Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support old way of declarative_base and fix fail messages #8

Merged
merged 2 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ See [atlasgo.io](https://atlasgo.io/getting-started#installation) for more insta

Install the provider by running:
```bash
# The Provider works by importing your SQLAlchemy models and extracting the schema from them.
# Therefore, you will need to run the provider from within your project's Python environment.
pip install atlas-provider-sqlalchemy
```

Expand Down
22 changes: 11 additions & 11 deletions atlas_provider_sqlalchemy/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ModelsNotFoundError(Exception):
pass


def get_declarative_base(models_dir: Path, debug: bool = False) -> Type[DeclarativeBase]:
def get_declarative_base(models_dir: Path, skip_errors: bool = False) -> Type[DeclarativeBase]:
"""
Walk the directory tree starting at the root, import all models, and return 1 of them, as they all keep a
reference to the Metadata object. The way sqlalchemy works, you must import all classes in order for them to be
Expand All @@ -29,20 +29,20 @@ def get_declarative_base(models_dir: Path, debug: bool = False) -> Type[Declarat
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
except Exception as e:
if debug:
print(f'{e.__class__.__name__}: {str(e)}')
# TODO: handle nicer
if skip_errors:
continue
print(f'{e.__class__.__name__}: {str(e)} in {file_path}')
print("To skip on failed import, run: atlas-provider-sqlalchemy --skip-errors")
exit(1)
continue
classes = {c[1]
for c in inspect.getmembers(module, inspect.isclass)
if issubclass(c[1], DeclarativeBase) and c[1] is not DeclarativeBase}
if hasattr(c[1], "metadata") and c[1] is not DeclarativeBase}
models.update(classes)
try:
model = models.pop()
except KeyError:
raise ModelsNotFoundError(
'Found no sqlalchemy models in the directory tree.')
return model
if not models:
print('Found no sqlalchemy models in the directory tree.')
exit(1)
return models.pop()


def dump_ddl(dialect_driver: str, base: Type[DeclarativeBase]) -> Type[DeclarativeBase]:
Expand Down
12 changes: 6 additions & 6 deletions atlas_provider_sqlalchemy/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import Optional
from enum import Enum
import typer
from sqlalchemy.orm import DeclarativeBase
Expand All @@ -18,18 +17,19 @@ class Dialect(str, Enum):
mssql = "mssql"


def run(dialect: Dialect, path: Path, debug: bool = False) -> Type[DeclarativeBase]:
base = get_declarative_base(path, debug)
def run(dialect: Dialect, path: Path, skip_errors: bool = False) -> Type[DeclarativeBase]:
base = get_declarative_base(path, skip_errors)
return dump_ddl(dialect.value, base)


@app.command()
def load(dialect: Dialect = Dialect.mysql,
path: Optional[Path] = typer.Option(None, exists=True, help="Path to directory of the sqlalchemy models."),
debug: bool = False):
path: Path = typer.Option(exists=True, help="Path to directory of the sqlalchemy models."),
skip_errors: bool = typer.Option(False, help="Skip errors when loading models.")
):
if path is None:
path = Path(os.getcwd())
run(dialect, path, debug)
run(dialect, path, skip_errors)


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions tests/old_models/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Optional
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.declarative import declarative_base

# Using the old way of declaring a declarative base
Base = declarative_base()


class User(Base):
__tablename__ = "user_account"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(30))
fullname: Mapped[Optional[str]] = mapped_column(String(30))
addresses: Mapped[List["Address"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)

def __repr__(self) -> str:
return f"User(id={self.id!r}, name={self.name!r}, fullname={self.fullname!r})"


class Address(Base):
__tablename__ = "address"
id: Mapped[int] = mapped_column(primary_key=True)
email_address: Mapped[str] = mapped_column(String(30))
user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
user: Mapped["User"] = relationship(back_populates="addresses")

def __repr__(self) -> str:
return f"Address(id={self.id!r}, email_address={self.email_address!r})"
22 changes: 16 additions & 6 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from atlas_provider_sqlalchemy.ddl import ModelsNotFoundError
from atlas_provider_sqlalchemy.main import run, get_declarative_base, Dialect
from pathlib import Path
from sqlalchemy.orm import DeclarativeBase
Expand All @@ -23,9 +22,18 @@ def test_run_mysql(capsys):
base.metadata.clear()


def test_get_declarative_base():
base = get_declarative_base(Path("tests"))
assert issubclass(base, DeclarativeBase)
def test_run_old_declarative_base(capsys):
with open('tests/ddl_mysql.sql', 'r') as f:
expected_ddl = f.read()
base = run(Dialect.mysql, Path("tests/old_models"))
captured = capsys.readouterr()
assert captured.out == expected_ddl
base.metadata.clear()


def test_get_old_declarative_base():
base = get_declarative_base(Path("tests/old_models"))
assert not issubclass(base, DeclarativeBase)
base.metadata.clear()


Expand All @@ -35,7 +43,9 @@ def test_get_declarative_base_explicit_path():
base.metadata.clear()


def test_get_declarative_base_explicit_path_fail():
with pytest.raises(ModelsNotFoundError, match='Found no sqlalchemy models in the directory tree.'):
def test_get_declarative_base_explicit_path_fail(capsys):
with pytest.raises(SystemExit):
base = get_declarative_base(Path("nothing/here"))
base.metadata.clear()
captured = capsys.readouterr()
assert captured.out == 'Found no sqlalchemy models in the directory tree.\n'