From 14a66ba615f92416e67f2f714bb0698c13128890 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 4 Oct 2024 23:37:29 +0100 Subject: [PATCH] more tests --- piccolo/query/methods/objects.py | 19 +++++++--- tests/table/instance/test_get_related.py | 48 +++++++++++++++++++++--- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index eff04020..f351f131 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -242,18 +242,22 @@ async def run( node: t.Optional[str] = None, in_pool: bool = True, ) -> t.Optional[ReferencedTable]: - references = t.cast( - t.Type[ReferencedTable], - self.foreign_key._foreign_key_meta.resolved_references, - ) + if not self.row._exists_in_db: + raise ValueError("The object doesn't exist in the database.") + + root_table = self.row.__class__ data = ( - await self.row.__class__.select( + await root_table.select( *[ i.as_alias(i._meta.name) for i in self.foreign_key.all_columns() ] ) + .where( + root_table._meta.primary_key + == getattr(self.row, root_table._meta.primary_key._meta.name) + ) .first() .run(node=node, in_pool=in_pool) ) @@ -262,6 +266,11 @@ async def run( if data is None or not any(data.values()): return None + references = t.cast( + t.Type[ReferencedTable], + self.foreign_key._foreign_key_meta.resolved_references, + ) + referenced_object = references(**data) referenced_object._exists_in_db = True return referenced_object diff --git a/tests/table/instance/test_get_related.py b/tests/table/instance/test_get_related.py index 6cae3b9f..a5493891 100644 --- a/tests/table/instance/test_get_related.py +++ b/tests/table/instance/test_get_related.py @@ -10,6 +10,9 @@ class TestGetRelated(AsyncTableTest): async def asyncSetUp(self): await super().asyncSetUp() + # Setup two pairs of manager/band, so we can make sure the correct + # objects are returned. + self.manager = Manager(name="Guido") await self.manager.save() @@ -18,13 +21,25 @@ async def asyncSetUp(self): ) await self.band.save() + self.manager_2 = Manager(name="Graydon") + await self.manager_2.save() + + self.band_2 = Band( + name="Rustaceans", manager=self.manager_2.id, popularity=100 + ) + await self.band_2.save() + async def test_foreign_key(self) -> None: """ Make sure you can get a related object from another object instance. """ manager = await self.band.get_related(Band.manager) assert manager is not None - self.assertTrue(manager.name == "Guido") + self.assertTrue(manager.id == self.manager.id) + + manager_2 = await self.band_2.get_related(Band.manager) + assert manager_2 is not None + self.assertTrue(manager_2.id == self.manager_2.id) async def test_non_foreign_key(self): """ @@ -38,7 +53,7 @@ async def test_string(self): Make sure it also works using a string representation of a foreign key. """ manager = t.cast(Manager, await self.band.get_related("manager")) - self.assertTrue(manager.name == "Guido") + self.assertTrue(manager.id == self.manager.id) async def test_invalid_string(self): """ @@ -51,12 +66,33 @@ async def test_multiple_levels(self): """ Make sure ``get_related`` works multiple levels deep. """ - concert = Concert(band_1=self.band) + concert = Concert(band_1=self.band, band_2=self.band_2) await concert.save() manager = await concert.get_related(Concert.band_1._.manager) assert manager is not None - self.assertTrue(manager.name == "Guido") + self.assertTrue(manager.id == self.manager.id) + + manager_2 = await concert.get_related(Concert.band_2._.manager) + assert manager_2 is not None + self.assertTrue(manager_2.id == self.manager_2.id) - band_2_manager = await concert.get_related(Concert.band_2._.manager) - assert band_2_manager is None + async def test_no_match(self): + """ + If not related object exists, make sure ``None`` is returned. + """ + concert = Concert(band_1=self.band, band_2=None) + await concert.save() + + manager_2 = await concert.get_related(Concert.band_2._.manager) + assert manager_2 is None + + async def test_not_in_db(self): + """ + If the object we're calling ``get_related`` on doesn't exist in the + database, then make sure an error is raised. + """ + concert = Concert(band_1=self.band, band_2=self.band_2) + + with self.assertRaises(ValueError): + await concert.get_related(Concert.band_1._.manager)