SQLModel Compatibility #271

Open · pavja2 opened this issue Feb 28, 2022 · 8 comments

@pavja2 commented Feb 28, 2022

SQLModel (https://github.com/tiangolo/sqlmodel) is notionally built as a layer on top of SQLAlchemy, so it would be great to use SQLAlchemy-Continuum to version databases managed by SQLModel. However, the basic getting-started example raises an error when the SQLAlchemy model is replaced with a SQLModel version. I'm not entirely clear on how SQLModel inherits from SQLAlchemy or to what degree compatibility is possible, but if there's a simple fix to expose the missing attributes, it would be fantastic.

Tutorial example, but using SQLModel:

from typing import Optional
from sqlalchemy_continuum import make_versioned
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import configure_mappers
from sqlmodel import SQLModel, Field

make_versioned(user_cls=None)

# SQLModel tables register in SQLModel.metadata; no declarative_base() is needed.
class Article(SQLModel, table=True):
    __versioned__ = {}
    __tablename__ = 'article'
    id: Optional[int] = Field(default=None, primary_key=True)
    name: Optional[str] = Field(default=None)
    content: Optional[str] = Field(default=None)

configure_mappers()  # fails here with the AttributeError below
engine = create_engine('sqlite://')
SQLModel.metadata.create_all(engine)
session = Session(bind=engine)

article = Article(name='Some article', content='Some content')
session.add(article)
session.commit()
print(article.versions[0].name)
article.name = 'Updated name'
session.commit()
print(article.versions[1].name)
article.versions[0].revert()
print(article.name)

Error output:

  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/factory.py", line 11, in __call__
    registry = Base.registry._class_registry
AttributeError: type object 'SQLModel' has no attribute 'registry'

  File "db_scratch.py", line 18, in <module>
    configure_mappers()
  File "/env/lib/python3.8/site-packages/sqlalchemy/orm/mapper.py", line 3388, in configure_mappers
    _configure_registries(_all_registries(), cascade=True)
  File "/env/lib/python3.8/site-packages/sqlalchemy/orm/mapper.py", line 3421, in _configure_registries
    Mapper.dispatch._for_class(Mapper).after_configured()
  File "/env/lib/python3.8/site-packages/sqlalchemy/event/attr.py", line 256, in __call__
    fn(*args, **kw)
  File "/env/lib/python3.8/site-packages/sqlalchemy/orm/events.py", line 743, in wrap
    fn(*arg, **kw)
  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/builder.py", line 22, in check_reentry
    handler(*args, **kwargs)
  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/builder.py", line 177, in configure_versioned_classes
    self.build_transaction_class()
  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/builder.py", line 154, in build_transaction_class
    self.manager.create_transaction_model()
  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/manager.py", line 166, in create_transaction_model
    self.transaction_cls = self.transaction_cls(self)
  File "/env/lib/python3.8/site-packages/sqlalchemy_continuum/factory.py", line 13, in __call__
    registry = Base._decl_class_registry
AttributeError: type object 'SQLModel' has no attribute '_decl_class_registry'
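
Reading the two `factory.py` frames together, the transaction factory appears to look up the class registry first as `Base.registry._class_registry` and, when that fails, as `Base._decl_class_registry`. A SQLModel class has neither; as the workaround later in this thread shows, it keeps its registry under `_sa_registry`. The sketch below reproduces that lookup with plain stand-in classes (every name here is hypothetical, not a real SQLAlchemy or SQLModel object) to show why both attempts fail and what a fallback would have to look for.

```python
from typing import Any, Dict


class FakeRegistry:
    """Hypothetical stand-in for a SQLAlchemy registry."""

    def __init__(self) -> None:
        self._class_registry: Dict[str, Any] = {}


class DeclarativeLikeBase:
    """Stand-in for a declarative_base() class: exposes `.registry`."""

    registry = FakeRegistry()


class SQLModelLike:
    """Stand-in for SQLModel: its registry lives only in `_sa_registry`."""

    _sa_registry = FakeRegistry()


def lookup_as_in_traceback(base: type) -> Dict[str, Any]:
    # Mirrors the two lookups visible in factory.py, lines 11 and 13.
    try:
        return base.registry._class_registry
    except AttributeError:
        return base._decl_class_registry


def lookup_with_sqlmodel_fallback(base: type) -> Dict[str, Any]:
    # Hypothetical variant that also tries SQLModel's `_sa_registry`.
    for attr in ("registry", "_sa_registry"):
        reg = getattr(base, attr, None)
        if reg is not None:
            return reg._class_registry
    return base._decl_class_registry


if __name__ == "__main__":
    print(lookup_as_in_traceback(DeclarativeLikeBase))
    try:
        lookup_as_in_traceback(SQLModelLike)
    except AttributeError as exc:
        print("fails like the report:", exc)
    print(lookup_with_sqlmodel_fallback(SQLModelLike))
```

The second call fails with the same kind of message as the report, naming `_decl_class_registry`, because the first lookup's `AttributeError` sends it straight into the second.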
@alvarolloret

Any updates on this? :)

@marksteward (Collaborator)

Update? I think it's a feature request.

@hasansezertasan

Yep, it sounds like a feature request, yet one that is totally out of sqlalchemy-continuum's scope.

@martinezger

Hey, does anyone have a solution for this?

@CiberNin

I'd also like to register my interest in this compatibility.

@dries007

I am also interested.

@AlePiccin

Can't wait for this to be compatible. I really need this feature.

@AlePiccin

I found a solution that works with these versions:

fastapi[standard]==0.115.5
pydantic==2.8.2
pydantic-settings==2.2.1
sqlmodel==0.0.22
sqlalchemy==2.0.36
sqlalchemy-utils==0.41.2
sqlalchemy-continuum==1.4.2

Before declaring any model, do the following:

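import sqlalchemy_utils


def new_get_declarative_base(model):
    import sqlmodel
    if issubclass(model, sqlmodel.SQLModel):
        # SQLModel stores its registry as _sa_registry; build a declarative
        # base from it so Continuum finds the .registry attribute it expects.
        return model._sa_registry.generate_base()
    else:
        # Other classes: climb to the top-most ancestor that still has
        # a `metadata` attribute.
        for parent in model.__bases__:
            try:
                _ = parent.metadata
                return new_get_declarative_base(parent)
            except AttributeError:
                pass
        return model


sqlalchemy_utils.functions.get_declarative_base = new_get_declarative_base

# Imported after the patch, matching the order in the full example below.
import sqlalchemy_continuum
import sqlmodel

sqlmodel.SQLModel.metadata = sqlmodel.main.default_registry.metadata

sqlalchemy_continuum.make_versioned()


class BaseSQLModel(sqlmodel.SQLModel):
    __versioned__ = {}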

This should work with SQLModel models: make your models inherit from BaseSQLModel instead of SQLModel.
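
To see what the patched `new_get_declarative_base` changes, here is the same recursion run against stand-in classes. Everything in it is hypothetical: `FakeRegistry.generate_base()` only imitates SQLAlchemy's `registry.generate_base()` by returning a new class that carries `.registry` and `.metadata`, and `FakeSQLModel` plays the role of `sqlmodel.SQLModel`. It uses `hasattr` instead of the `try`/`except` in the comment, which has the same effect here.

```python
from typing import Any


class FakeRegistry:
    """Hypothetical stand-in for sqlalchemy.orm.registry (not the real class)."""

    def __init__(self) -> None:
        self.metadata: Any = object()  # placeholder for a MetaData instance

    def generate_base(self) -> type:
        # Imitates registry.generate_base(): a new base class tied to this registry.
        return type("GeneratedBase", (), {"registry": self, "metadata": self.metadata})


class FakeSQLModel:
    """Stand-in for sqlmodel.SQLModel: registry kept only as `_sa_registry`."""

    _sa_registry = FakeRegistry()
    metadata = _sa_registry.metadata


def patched_get_declarative_base(model: type) -> type:
    if issubclass(model, FakeSQLModel):
        # SQLModel branch: hand back a base that exposes `.registry`.
        return model._sa_registry.generate_base()
    # Other classes: climb to the top-most ancestor that still has `metadata`.
    for parent in model.__bases__:
        if hasattr(parent, "metadata"):
            return patched_get_declarative_base(parent)
    return model


class PlainBase:
    """Stand-in for a classic declarative_base() class."""

    metadata: Any = object()


class PlainModel(PlainBase):
    pass


class Article(FakeSQLModel):
    pass


if __name__ == "__main__":
    base = patched_get_declarative_base(Article)
    print(base.__name__, base.registry is FakeSQLModel._sa_registry)  # GeneratedBase True
    print(patched_get_declarative_base(PlainModel) is PlainBase)  # True
```

In this sketch every call for a SQLModel subclass builds a distinct `GeneratedBase` class, although all of them share the same registry and metadata.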

If you use SQL Server as your database, you may hit an error caused by the sequence defined on the id column of the Transaction class. You can override the Transaction class as shown below to work around it. I also use a FastAPIPlugin to take the user from the session; all you need to do is store the user data in the session.info dictionary. Here is the complete code I use:

import sqlalchemy_utils


def new_get_declarative_base(model):
    import sqlmodel
    if issubclass(model, sqlmodel.SQLModel):
        # noinspection PyProtectedMember
        return model._sa_registry.generate_base()
    else:
        for parent in model.__bases__:
            try:
                _ = parent.metadata
                return new_get_declarative_base(parent)
            except AttributeError:
                pass
        return model


sqlalchemy_utils.functions.get_declarative_base = new_get_declarative_base

import sqlalchemy_continuum


def new_create_class(self, manager):
    """
    Create Transaction class.
    """
    import sqlalchemy as sa
    from collections import OrderedDict

    class Transaction(manager.declarative_base, sqlalchemy_continuum.transaction.TransactionBase):
        __tablename__ = 'transaction'
        __versioning_manager__ = manager

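        # No Sequence is attached to this id column; leaving it out is what
        # avoids the SQL Server error described above.
        id = sa.Column(sa.types.BigInteger, primary_key=True, autoincrement=True)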

        if self.remote_addr:
            remote_addr = sa.Column(sa.String(50))

        if manager.user_cls:
            user_cls = manager.user_cls
            Base = manager.declarative_base
            # noinspection PyProtectedMember
            registry = Base.registry._class_registry

            if isinstance(user_cls, str):
                try:
                    user_cls = registry[user_cls]
                except KeyError:
                    raise sqlalchemy_continuum.ImproperlyConfigured('Could not build relationship between Transaction'
                                                                    ' and %s. %s was not found in declarative class '
                                                                    'registry. Either configure VersioningManager to '
                                                                    'use different user class or disable this '
                                                                    'relationship ' % (user_cls, user_cls))

            user_id = sa.Column(
                sa.inspect(user_cls).primary_key[0].type,
                sa.ForeignKey(sa.inspect(user_cls).primary_key[0]),
                index=True,
            )

            user = sa.orm.relationship(user_cls)

        def __repr__(self):
            fields = ['id', 'issued_at', 'user']
            field_values = OrderedDict(
                (field, getattr(self, field)) for field in fields if hasattr(self, field)
            )
            return '<Transaction %s>' % ', '.join(
                '%s=%r' % (field, value) for field, value in field_values.items()
            )

    if manager.options['native_versioning']:
        sqlalchemy_continuum.transaction.create_triggers(Transaction)
    return Transaction


sqlalchemy_continuum.transaction.TransactionFactory.create_class = new_create_class

from sqlalchemy_continuum.plugins import Plugin


class FastAPIPlugin(Plugin):

    def transaction_args(self, uow, session):
        # The application stores the current user's id in session.info['id_usuario'].
        return {'user_id': session.info.get('id_usuario'), 'remote_addr': None}


import sqlmodel
from sqlalchemy.orm import declared_attr

constraint_naming_conventions = {"ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s",
                                 "ck": "ck_%(table_name)s_%(constraint_name)s",
                                 "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
                                 "pk": "pk_%(table_name)s"}

sqlmodel.main.default_registry.metadata = sqlmodel.MetaData(naming_convention=constraint_naming_conventions)

sqlmodel.SQLModel.metadata = sqlmodel.main.default_registry.metadata

sqlalchemy_continuum.make_versioned(user_cls="Usuario", plugins=[FastAPIPlugin()])


class BaseSQLModel(sqlmodel.SQLModel):
    __versioned__ = {}

    @declared_attr  # type: ignore
    def __tablename__(cls) -> str:
        return cls.__name__
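
The `FastAPIPlugin` above takes the transaction's `user_id` from `session.info.get('id_usuario')`, so whatever opens the session for a request has to put the user there first. A minimal sketch, assuming a callable `session_factory` (for example a `sessionmaker`) whose sessions expose the standard `info` dictionary and a `close()` method; the helper names `session_for_user` and `StandInSession` are hypothetical, and the stand-in lets the demo run without a database.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional


@contextmanager
def session_for_user(session_factory: Callable[[], Any], user_id: Optional[int]) -> Iterator[Any]:
    """Open a session tagged with the user that FastAPIPlugin will record."""
    session = session_factory()
    # The key must match the one FastAPIPlugin.transaction_args reads.
    session.info["id_usuario"] = user_id
    try:
        yield session
    finally:
        session.info.pop("id_usuario", None)
        session.close()


def transaction_args(session: Any) -> dict:
    # Same lookup as FastAPIPlugin.transaction_args, without the Plugin base class.
    return {"user_id": session.info.get("id_usuario"), "remote_addr": None}


class StandInSession:
    """Hypothetical stand-in exposing only what the helper touches."""

    def __init__(self) -> None:
        self.info: dict = {}
        self.closed = False

    def close(self) -> None:
        self.closed = True


if __name__ == "__main__":
    with session_for_user(StandInSession, user_id=7) as s:
        print(transaction_args(s))  # {'user_id': 7, 'remote_addr': None}
    print(transaction_args(s), s.closed)  # {'user_id': None, 'remote_addr': None} True
```

In a FastAPI app, the same body could live in a dependency that uses `yield`, so the user is attached for the whole request and removed afterwards.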
