From b42daf7d2d07fb617a96589c349754842d3e9e4f Mon Sep 17 00:00:00 2001 From: sinisaos Date: Mon, 9 Dec 2024 20:54:14 +0100 Subject: [PATCH] update the branch code --- piccolo/columns/reverse_lookup.py | 24 +++++++---- piccolo/query/methods/select.py | 41 +++++++++---------- piccolo/table.py | 9 ++-- .../apps/asgi/commands/files/dummy_server.py | 4 +- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/piccolo/columns/reverse_lookup.py b/piccolo/columns/reverse_lookup.py index 4d622a02d..4feea4766 100644 --- a/piccolo/columns/reverse_lookup.py +++ b/piccolo/columns/reverse_lookup.py @@ -4,7 +4,7 @@ import typing as t from dataclasses import dataclass -from piccolo.columns.base import Selectable +from piccolo.columns.base import QueryString, Selectable from piccolo.columns.column_types import ( JSON, JSONB, @@ -53,7 +53,9 @@ def __init__( for column in columns ) - def get_select_string(self, engine_type: str, with_alias=True) -> str: + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: reverse_lookup_name = self.reverse_lookup._meta.name table1 = self.reverse_lookup._meta.table @@ -74,22 +76,26 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: if engine_type in ("postgres", "cockroach"): if self.as_list: column_name = self.columns[0]._meta.db_column_name - return f""" + return QueryString( + f""" ARRAY( SELECT "{table2_name}"."{column_name}" FROM {reverse_select} ) AS "{reverse_lookup_name}" """ + ) elif not self.serialisation_safe: column_name = table2_pk - return f""" + return QueryString( + f""" ARRAY( SELECT "{table2_name}"."{column_name}" FROM {reverse_select} ) AS "{reverse_lookup_name}" """ + ) else: if len(self.columns) > 0: column_names = ", ".join( @@ -101,7 +107,8 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: f'"{table2_name}"."{column._meta.db_column_name}"' # noqa: E501 for column in table2._meta.columns ) - return f""" + return QueryString( + f""" ( SELECT JSON_AGG("{table2_name}s") FROM ( @@ -109,6 +116,7 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: ) AS "{table2_name}s" ) AS "{reverse_lookup_name}" """ + ) elif engine_type == "sqlite": if len(self.columns) > 1 or not self.serialisation_safe: column_name = table2_pk @@ -118,15 +126,17 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: except IndexError: column_name = table2_pk - return f""" + return QueryString( + f""" ( SELECT group_concat( "{table2_name}"."{column_name}" ) FROM {reverse_select} ) - AS "{table2_name}s [M2M]" + AS "{reverse_lookup_name} [M2M]" """ + ) else: raise ValueError(f"{engine_type} is an unrecognised engine type") diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 4d77b3bf1..995bf4530 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -250,7 +250,7 @@ async def _splice_m2m_rows( secondary_table: t.Type[Table], secondary_table_pk: Column, m2m_name: str, - m2m_select: M2MSelect, + m2m_select: t.Union[M2MSelect, ReverseLookupSelect], as_list: bool = False, ): row_ids = list( @@ -287,7 +287,7 @@ async def _splice_m2m_rows( row[m2m_name] = [extra_rows_map.get(i) for i in row[m2m_name]] return response - async def response_handler(self, response): + async def response_handler(self, response: t.List[t.Dict[str, t.Any]]): m2m_selects = [ i for i in self.columns_delegate.selected_columns @@ -409,12 +409,9 @@ async def response_handler(self, response): ) try: for row in response: - data = row[f"{reverse_lookup_name}"] - row[f"{reverse_lookup_name}"] = ( - [ - value_type(i) - for i in row[f"{reverse_lookup_name}"] - ] + data = row[reverse_lookup_name] + row[reverse_lookup_name] = ( + [value_type(i) for i in row[reverse_lookup_name]] if data else [] ) @@ -436,7 +433,7 @@ async def response_handler(self, response): response, reverse_table, reverse_table._meta.primary_key, - f"{reverse_lookup_name}", + reverse_lookup_name, reverse_lookup_select, as_list=True, ) @@ -449,11 +446,11 @@ async def response_handler(self, response): 0 ]._meta.name for row in response: - if row[f"{reverse_lookup_name}"] is None: - row[f"{reverse_lookup_name}"] = [] - row[f"{reverse_lookup_name}"] = [ + if row[reverse_lookup_name] is None: + row[reverse_lookup_name] = [] + row[reverse_lookup_name] = [ {column_name: i} - for i in row[f"{reverse_lookup_name}"] + for i in row[reverse_lookup_name] ] elif ( len(reverse_lookup_select.columns) == 0 @@ -464,7 +461,7 @@ async def response_handler(self, response): set( itertools.chain( *[ - row[f"{reverse_lookup_name}"] + row[reverse_lookup_name] for row in response ] ) @@ -500,16 +497,16 @@ async def response_handler(self, response): for row in extra_rows } for row in response: - row[f"{reverse_lookup_name}"] = [ + row[reverse_lookup_name] = [ extra_rows_map.get(i) - for i in row[f"{reverse_lookup_name}"] + for i in row[reverse_lookup_name] ] else: response = await self._splice_m2m_rows( response, reverse_table, reverse_table._meta.primary_key, - f"{reverse_lookup_name}", + reverse_lookup_name, reverse_lookup_select, as_list=False, ) @@ -522,8 +519,8 @@ async def response_handler(self, response): and reverse_lookup_select.load_json ): for row in response: - data = row[reverse_lookup_select.columns[0]] - row[reverse_lookup_select.columns[0]] = [ + data = row[str(reverse_lookup_select.columns[0])] + row[str(reverse_lookup_select.columns[0])] = [ load_json(i) for i in data ] @@ -532,8 +529,8 @@ async def response_handler(self, response): # are returned as a JSON string, so we need to deserialise # it. for row in response: - data = row[f"{reverse_lookup_name}"] - row[f"{reverse_lookup_name}"] = ( + data = row[reverse_lookup_name] + row[reverse_lookup_name] = ( load_json(data) if data else [] ) else: @@ -544,7 +541,7 @@ async def response_handler(self, response): response, reverse_table, reverse_table._meta.primary_key, - f"{reverse_lookup_name}", + reverse_lookup_name, reverse_lookup_select, as_list=False, ) diff --git a/piccolo/table.py b/piccolo/table.py index 120e1cbf1..bb4683803 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -89,7 +89,9 @@ class TableMeta: tags: t.List[str] = field(default_factory=list) help_text: t.Optional[str] = None _db: t.Optional[Engine] = None - m2m_relationships: t.List[M2M] = field(default_factory=list) + m2m_relationships: t.List[t.Union[M2M, ReverseLookup]] = field( + default_factory=list + ) schema: t.Optional[str] = None # Records reverse foreign key relationships - i.e. when the current table @@ -279,7 +281,7 @@ def __init_subclass__( email_columns: t.List[Email] = [] auto_update_columns: t.List[Column] = [] primary_key: t.Optional[Column] = None - m2m_relationships: t.List[M2M] = [] + m2m_relationships: t.List[t.Union[M2M, ReverseLookup]] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -330,8 +332,7 @@ def __init_subclass__( if isinstance(attribute, (M2M, ReverseLookup)): attribute._meta._name = attribute_name attribute._meta._table = cls - if isinstance(attribute, M2M): - m2m_relationships.append(attribute) + m2m_relationships.append(attribute) if not primary_key: primary_key = cls._create_serial_primary_key() diff --git a/tests/apps/asgi/commands/files/dummy_server.py b/tests/apps/asgi/commands/files/dummy_server.py index 9b83470a3..a4807aa66 100644 --- a/tests/apps/asgi/commands/files/dummy_server.py +++ b/tests/apps/asgi/commands/files/dummy_server.py @@ -3,7 +3,7 @@ import sys import typing as t -from httpx import AsyncClient +from httpx import ASGITransport, AsyncClient async def dummy_server(app: t.Union[str, t.Callable] = "app:app"): @@ -24,7 +24,7 @@ async def dummy_server(app: t.Union[str, t.Callable] = "app:app"): module = importlib.import_module(path) app = t.cast(t.Callable, getattr(module, app_name)) - async with AsyncClient(app=app) as client: + async with AsyncClient(transport=ASGITransport(app=app)) as client: response = await client.get("http://localhost:8000") if response.status_code != 200: sys.exit("The app isn't callable!")