Skip to content

Commit

Permalink
update the branch code
Browse files Browse the repository at this point in the history
  • Loading branch information
sinisaos committed Dec 9, 2024
1 parent 4b58c8c commit b42daf7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 35 deletions.
24 changes: 17 additions & 7 deletions piccolo/columns/reverse_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing as t
from dataclasses import dataclass

from piccolo.columns.base import Selectable
from piccolo.columns.base import QueryString, Selectable
from piccolo.columns.column_types import (
JSON,
JSONB,
Expand Down Expand Up @@ -53,7 +53,9 @@ def __init__(
for column in columns
)

def get_select_string(self, engine_type: str, with_alias=True) -> str:
def get_select_string(
self, engine_type: str, with_alias=True
) -> QueryString:
reverse_lookup_name = self.reverse_lookup._meta.name

table1 = self.reverse_lookup._meta.table
Expand All @@ -74,22 +76,26 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str:
if engine_type in ("postgres", "cockroach"):
if self.as_list:
column_name = self.columns[0]._meta.db_column_name
return f"""
return QueryString(
f"""
ARRAY(
SELECT
"{table2_name}"."{column_name}"
FROM {reverse_select}
) AS "{reverse_lookup_name}"
"""
)
elif not self.serialisation_safe:
column_name = table2_pk
return f"""
return QueryString(
f"""
ARRAY(
SELECT
"{table2_name}"."{column_name}"
FROM {reverse_select}
) AS "{reverse_lookup_name}"
"""
)
else:
if len(self.columns) > 0:
column_names = ", ".join(
Expand All @@ -101,14 +107,16 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str:
f'"{table2_name}"."{column._meta.db_column_name}"' # noqa: E501
for column in table2._meta.columns
)
return f"""
return QueryString(
f"""
(
SELECT JSON_AGG("{table2_name}s")
FROM (
SELECT {column_names} FROM {reverse_select}
) AS "{table2_name}s"
) AS "{reverse_lookup_name}"
"""
)
elif engine_type == "sqlite":
if len(self.columns) > 1 or not self.serialisation_safe:
column_name = table2_pk
Expand All @@ -118,15 +126,17 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str:
except IndexError:
column_name = table2_pk

return f"""
return QueryString(
f"""
(
SELECT group_concat(
"{table2_name}"."{column_name}"
)
FROM {reverse_select}
)
AS "{table2_name}s [M2M]"
AS "{reverse_lookup_name} [M2M]"
"""
)
else:
raise ValueError(f"{engine_type} is an unrecognised engine type")

Expand Down
41 changes: 19 additions & 22 deletions piccolo/query/methods/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def _splice_m2m_rows(
secondary_table: t.Type[Table],
secondary_table_pk: Column,
m2m_name: str,
m2m_select: M2MSelect,
m2m_select: t.Union[M2MSelect, ReverseLookupSelect],
as_list: bool = False,
):
row_ids = list(
Expand Down Expand Up @@ -287,7 +287,7 @@ async def _splice_m2m_rows(
row[m2m_name] = [extra_rows_map.get(i) for i in row[m2m_name]]
return response

async def response_handler(self, response):
async def response_handler(self, response: t.List[t.Dict[str, t.Any]]):
m2m_selects = [
i
for i in self.columns_delegate.selected_columns
Expand Down Expand Up @@ -409,12 +409,9 @@ async def response_handler(self, response):
)
try:
for row in response:
data = row[f"{reverse_lookup_name}"]
row[f"{reverse_lookup_name}"] = (
[
value_type(i)
for i in row[f"{reverse_lookup_name}"]
]
data = row[reverse_lookup_name]
row[reverse_lookup_name] = (
[value_type(i) for i in row[reverse_lookup_name]]
if data
else []
)
Expand All @@ -436,7 +433,7 @@ async def response_handler(self, response):
response,
reverse_table,
reverse_table._meta.primary_key,
f"{reverse_lookup_name}",
reverse_lookup_name,
reverse_lookup_select,
as_list=True,
)
Expand All @@ -449,11 +446,11 @@ async def response_handler(self, response):
0
]._meta.name
for row in response:
if row[f"{reverse_lookup_name}"] is None:
row[f"{reverse_lookup_name}"] = []
row[f"{reverse_lookup_name}"] = [
if row[reverse_lookup_name] is None:
row[reverse_lookup_name] = []
row[reverse_lookup_name] = [
{column_name: i}
for i in row[f"{reverse_lookup_name}"]
for i in row[reverse_lookup_name]
]
elif (
len(reverse_lookup_select.columns) == 0
Expand All @@ -464,7 +461,7 @@ async def response_handler(self, response):
set(
itertools.chain(
*[
row[f"{reverse_lookup_name}"]
row[reverse_lookup_name]
for row in response
]
)
Expand Down Expand Up @@ -500,16 +497,16 @@ async def response_handler(self, response):
for row in extra_rows
}
for row in response:
row[f"{reverse_lookup_name}"] = [
row[reverse_lookup_name] = [
extra_rows_map.get(i)
for i in row[f"{reverse_lookup_name}"]
for i in row[reverse_lookup_name]
]
else:
response = await self._splice_m2m_rows(
response,
reverse_table,
reverse_table._meta.primary_key,
f"{reverse_lookup_name}",
reverse_lookup_name,
reverse_lookup_select,
as_list=False,
)
Expand All @@ -522,8 +519,8 @@ async def response_handler(self, response):
and reverse_lookup_select.load_json
):
for row in response:
data = row[reverse_lookup_select.columns[0]]
row[reverse_lookup_select.columns[0]] = [
data = row[str(reverse_lookup_select.columns[0])]
row[str(reverse_lookup_select.columns[0])] = [
load_json(i) for i in data
]

Expand All @@ -532,8 +529,8 @@ async def response_handler(self, response):
# are returned as a JSON string, so we need to deserialise
# it.
for row in response:
data = row[f"{reverse_lookup_name}"]
row[f"{reverse_lookup_name}"] = (
data = row[reverse_lookup_name]
row[reverse_lookup_name] = (
load_json(data) if data else []
)
else:
Expand All @@ -544,7 +541,7 @@ async def response_handler(self, response):
response,
reverse_table,
reverse_table._meta.primary_key,
f"{reverse_lookup_name}",
reverse_lookup_name,
reverse_lookup_select,
as_list=False,
)
Expand Down
9 changes: 5 additions & 4 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ class TableMeta:
tags: t.List[str] = field(default_factory=list)
help_text: t.Optional[str] = None
_db: t.Optional[Engine] = None
m2m_relationships: t.List[M2M] = field(default_factory=list)
m2m_relationships: t.List[t.Union[M2M, ReverseLookup]] = field(
default_factory=list
)
schema: t.Optional[str] = None

# Records reverse foreign key relationships - i.e. when the current table
Expand Down Expand Up @@ -279,7 +281,7 @@ def __init_subclass__(
email_columns: t.List[Email] = []
auto_update_columns: t.List[Column] = []
primary_key: t.Optional[Column] = None
m2m_relationships: t.List[M2M] = []
m2m_relationships: t.List[t.Union[M2M, ReverseLookup]] = []

attribute_names = itertools.chain(
*[i.__dict__.keys() for i in reversed(cls.__mro__)]
Expand Down Expand Up @@ -330,8 +332,7 @@ def __init_subclass__(
if isinstance(attribute, (M2M, ReverseLookup)):
attribute._meta._name = attribute_name
attribute._meta._table = cls
if isinstance(attribute, M2M):
m2m_relationships.append(attribute)
m2m_relationships.append(attribute)

if not primary_key:
primary_key = cls._create_serial_primary_key()
Expand Down
4 changes: 2 additions & 2 deletions tests/apps/asgi/commands/files/dummy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import typing as t

from httpx import AsyncClient
from httpx import ASGITransport, AsyncClient


async def dummy_server(app: t.Union[str, t.Callable] = "app:app"):
Expand All @@ -24,7 +24,7 @@ async def dummy_server(app: t.Union[str, t.Callable] = "app:app"):
module = importlib.import_module(path)
app = t.cast(t.Callable, getattr(module, app_name))

async with AsyncClient(app=app) as client:
async with AsyncClient(transport=ASGITransport(app=app)) as client:
response = await client.get("http://localhost:8000")
if response.status_code != 200:
sys.exit("The app isn't callable!")
Expand Down

0 comments on commit b42daf7

Please sign in to comment.