Skip to content

Commit

Permalink
chore: add tests on sqlalchemy datasource
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarreau committed Nov 29, 2024
1 parent 8196f43 commit 0ebe10c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 5 deletions.
3 changes: 0 additions & 3 deletions src/_example/django/django_demo/app/forest_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def customize_forest(agent: DjangoAgent):

agent.customize_collection("address").add_segment("France", segment_addr_fr("address"))
agent.customize_collection("app_address").add_segment("France", segment_addr_fr("app_address"))
agent.customize_collection("app_customer_blocked_customer").rename_field("from_customer", "from").rename_field(
"to_customer", "to"
)

# # ## ADDRESS
agent.customize_collection("app_address").add_segment(
Expand Down
16 changes: 14 additions & 2 deletions src/datasource_sqlalchemy/tests/fixture/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os
from datetime import date, datetime

import sqlalchemy # type: ignore
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, create_engine, func, types
from sqlalchemy.orm import Session, declarative_base, relationship, validates
from sqlalchemy.orm import Session, relationship, validates

use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2"
test_db_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", "test_db.sql"))
engine = create_engine(f"sqlite:///{test_db_path}", echo=False)
fixtures_dir = os.path.abspath(os.path.join(__file__, ".."))
Expand Down Expand Up @@ -42,7 +44,17 @@ def __import__(cls, d):
return cls(**params)


Base = declarative_base(cls=_Base)
if use_sqlalchemy_2:
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase, _Base):
pass

else:
from sqlalchemy.orm import declarative_base

Base = declarative_base(cls=_Base)

Base.metadata.bind = engine


Expand Down
92 changes: 92 additions & 0 deletions src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import os
from unittest import TestCase
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -93,3 +95,93 @@ def test_with_models(self):

assert len(datasource._collections) == 4
assert datasource.get_collection("address").datasource == datasource


class TestSQLAlchemyDatasourceConnectionQueryCreation(TestCase):
def test_should_not_create_native_query_connection_if_no_params(self):
ds = SqlAlchemyDatasource(models.Base)
self.assertEqual(ds.get_native_query_connections(), [])

def test_should_create_native_query_connection_to_default_if_string_is_set(self):
ds = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy")
self.assertEqual(ds.get_native_query_connections(), ["sqlalchemy"])


class TestSQLAlchemyDatasourceNativeQueryExecution(TestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.loop = asyncio.new_event_loop()
if os.path.exists(models.test_db_path):
os.remove(models.test_db_path)
models.create_test_database()
models.load_fixtures()
cls.sql_alchemy_datasource = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy")

def test_should_raise_if_connection_is_not_known_by_datasource(self):
self.assertRaisesRegex(
SqlAlchemyDatasourceException,
r"The native query connection 'foo' doesn't belongs to this datasource.",
self.loop.run_until_complete,
self.sql_alchemy_datasource.execute_native_query("foo", "select * from blabla", {}),
)

def test_should_correctly_execute_query(self):
result = self.loop.run_until_complete(
self.sql_alchemy_datasource.execute_native_query(
"sqlalchemy", "select * from customer where id <= 2 order by id;", {}
)
)
self.assertEqual(
result,
[
{"id": 1, "first_name": "David", "last_name": "Myers", "age": 112},
{"id": 2, "first_name": "Thomas", "last_name": "Odom", "age": 92},
],
)

def test_should_correctly_execute_query_with_formatting(self):
result = self.loop.run_until_complete(
self.sql_alchemy_datasource.execute_native_query(
"sqlalchemy",
"""select *
from customer
where first_name = %(first_name)s
and last_name = %(last_name)s
order by id""",
{"first_name": "David", "last_name": "Myers"},
)
)
self.assertEqual(
result,
[
{"id": 1, "first_name": "David", "last_name": "Myers", "age": 112},
],
)

def test_should_correctly_execute_query_with_percent(self):
result = self.loop.run_until_complete(
self.sql_alchemy_datasource.execute_native_query(
"sqlalchemy",
"""select *
from customer
where first_name like 'Dav\\%'
order by id""",
{},
)
)

self.assertEqual(
result,
[
{"id": 1, "first_name": "David", "last_name": "Myers", "age": 112},
],
)

def test_should_correctly_raise_exception_during_sql_error(self):
self.assertRaisesRegex(
SqlAlchemyDatasourceException,
r"no such table: blabla",
self.loop.run_until_complete,
self.sql_alchemy_datasource.execute_native_query("sqlalchemy", "select * from blabla", {}),
)

0 comments on commit 0ebe10c

Please sign in to comment.