Skip to content

Commit

Permalink
Supporting SQLAlchemy Core and old versions of SQLAlchemy (#14)
Browse files Browse the repository at this point in the history
* add support for sqlalchemy tables

* add support for old versions of sqlalchemy
  • Loading branch information
vshender authored May 15, 2024
1 parent 77dbc03 commit eb38580
Show file tree
Hide file tree
Showing 11 changed files with 524 additions and 208 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ jobs:
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: poetry install
- uses: actions/cache@v3
name: Define a cache for the tox based on the tox config file
with:
path: ./.tox
key: tox-${{ hashFiles('tox.ini') }}
- name: Run unit tests
run: poetry run pytest
run: poetry run tox
- name: Run lint
run: poetry run ruff --output-format=github .

Expand All @@ -52,7 +57,7 @@ jobs:
path: ./.venv
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: poetry install
run: poetry install --with inttests
- name: Install atlas
uses: ariga/setup-atlas@master
- name: Run Test as Standalone
Expand Down
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
*~

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -159,4 +161,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

TODO.md
# Direnv
.envrc

TODO.md
106 changes: 65 additions & 41 deletions atlas_provider_sqlalchemy/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,91 @@
import importlib.util
import inspect
from pathlib import Path
from sqlalchemy import create_mock_engine
from sqlalchemy.orm import DeclarativeBase
from typing import Type, Set, List
from typing import Any, Protocol

import sqlalchemy as sa


class DBTableDesc(Protocol):
"""Database table description (SQLAlchemy table or model)."""

metadata: sa.MetaData


class ModuleImportError(Exception):
pass


class ModelsNotFoundError(Exception):
pass


def get_declarative_base(models_dir: Path, skip_errors: bool = False) -> Type[DeclarativeBase]:
"""
Walk the directory tree starting at the root, import all models, and return 1 of them, as they all keep a
reference to the Metadata object. The way sqlalchemy works, you must import all classes in order for them to be
registered in Metadata.
def sqlalchemy_version() -> tuple[int, ...]:
"""Get major and minor version of sqlalchemy."""

return tuple(int(x) for x in sa.__version__.split("."))


def create_mock_engine(url: str, executor: Any) -> Any:
"""Create a "mock" engine used for echoing DDL."""

if sqlalchemy_version() < (1, 4):
return sa.create_engine(url, strategy="mock", executor=executor)
else:
return sa.create_mock_engine(url, executor)


def get_metadata(db_dir: Path, skip_errors: bool = False) -> sa.MetaData:
"""Walk the directory tree starting at the root, import all models and
tables, and return metadata for one of them, as they all keep a reference
to the `MetaData` object. The way SQLAlchemy works, you must import all
models and tables in order for them to be registered in metadata.
"""

models: Set[Type[DeclarativeBase]] = set()
for root, _, _ in os.walk(models_dir):
python_file_paths = Path(root).glob('*.py')
metadata: set[sa.MetaData] = set()

for root, _, _ in os.walk(db_dir):
python_file_paths = Path(root).glob("*.py")
for file_path in python_file_paths:
try:
module_spec = importlib.util.spec_from_file_location(
file_path.stem, file_path)
file_path.stem,
file_path,
)
if module_spec and module_spec.loader:
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
except Exception as e:
if skip_errors:
continue
print(f'{e.__class__.__name__}: {str(e)} in {file_path}')
print("To skip on failed import, run: atlas-provider-sqlalchemy --skip-errors")
exit(1)
continue
classes = {c[1]
for c in inspect.getmembers(module, inspect.isclass)
if hasattr(c[1], "metadata") and c[1] is not DeclarativeBase}
models.update(classes)
if not models:
print('Found no sqlalchemy models in the directory tree.')
exit(1)
return models.pop()


def dump_ddl(dialect_driver: str, base: Type[DeclarativeBase]) -> Type[DeclarativeBase]:
"""
Creates a mock engine and dumps its DDL to stdout
"""

def dump(sql, *multiparams, **params):
print(str(sql.compile(dialect=engine.dialect)).replace('\t', '').replace('\n', ''), end=';\n\n')
raise ModuleImportError(f"{e.__class__.__name__}: {str(e)} in {file_path}")

engine = create_mock_engine(f'{dialect_driver}://', dump)
base.metadata.create_all(engine, checkfirst=False)
return base
ms = {
v.metadata
for (_, v) in inspect.getmembers(module)
if hasattr(v, "metadata") and isinstance(v.metadata, sa.MetaData)
}
metadata.update(ms)

if not metadata:
raise ModelsNotFoundError("Found no sqlalchemy models/tables in the directory tree.")

return metadata.pop()


def dump_ddl(dialect_driver: str, metadata: sa.MetaData) -> sa.MetaData:
"""Dump DDL statements for the given metadata to stdout."""

def dump(sql, *multiparams, **params):
print(str(sql.compile(dialect=engine.dialect)).replace("\t", "").replace("\n", ""), end=";\n\n")

engine = create_mock_engine(f"{dialect_driver}://", dump)
metadata.create_all(engine, checkfirst=False)
return metadata

def get_import_path_from_path(path: Path, root_dir: Path) -> str:
import_path = '.'.join(path.relative_to(
root_dir).parts).replace(path.suffix, '')
return import_path

def print_ddl(dialect_driver: str, models: list[DBTableDesc]) -> None:
"""Dump DDL statements for the metadata from the given models/tables to stdout."""

def print_ddl(dialect_driver: str, models: List[Type[DeclarativeBase]]):
dump_ddl(dialect_driver=dialect_driver, base=models[0])
dump_ddl(dialect_driver=dialect_driver, metadata=models[0].metadata)
32 changes: 23 additions & 9 deletions atlas_provider_sqlalchemy/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import os
from pathlib import Path
from enum import Enum
from pathlib import Path

import typer
from sqlalchemy.orm import DeclarativeBase
from typing import Type
from atlas_provider_sqlalchemy.ddl import get_declarative_base, dump_ddl
from sqlalchemy import MetaData

from atlas_provider_sqlalchemy.ddl import (
ModuleImportError,
ModelsNotFoundError,
dump_ddl,
get_metadata,
)

app = typer.Typer(no_args_is_help=True)

Expand All @@ -17,9 +23,9 @@ class Dialect(str, Enum):
mssql = "mssql"


def run(dialect: Dialect, path: Path, skip_errors: bool = False) -> Type[DeclarativeBase]:
base = get_declarative_base(path, skip_errors)
return dump_ddl(dialect.value, base)
def run(dialect: Dialect, path: Path, skip_errors: bool = False) -> MetaData:
metadata = get_metadata(path, skip_errors)
return dump_ddl(dialect.value, metadata)


@app.command()
Expand All @@ -29,8 +35,16 @@ def load(dialect: Dialect = Dialect.mysql,
):
if path is None:
path = Path(os.getcwd())
run(dialect, path, skip_errors)
try:
run(dialect, path, skip_errors)
except ModuleImportError as e:
print(e)
print("To skip on failed import, run: atlas-provider-sqlalchemy --skip-errors")
exit(1)
except ModelsNotFoundError as e:
print(e)
exit(1)


if __name__ == "__main__":
app(prog_name='atlas-provider-sqlalchemy')
app(prog_name="atlas-provider-sqlalchemy")
Loading

0 comments on commit eb38580

Please sign in to comment.