Skip to content

Commit

Permalink
Add sort and order by query params to v3 endpoints (#285)
Browse files Browse the repository at this point in the history
* order by classes

* refactor query builder to handle sort and order by statements

* add sort and order by to v3 endpoints

resolves #163
---------

Co-authored-by: Gabriel Fosse <[email protected]>
  • Loading branch information
russbiggs and Gabriel Fosse authored Sep 27, 2023
1 parent 29d0b7a commit 75a2efd
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 135 deletions.
228 changes: 132 additions & 96 deletions openaq_api/openaq_api/v3/models/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import types
import weakref
from datetime import date, datetime
from enum import StrEnum
from enum import StrEnum, auto
from types import FunctionType
from typing import Annotated, Any
from abc import ABC

import fastapi
import humps
Expand All @@ -26,6 +27,7 @@
)
from pydantic_core import CoreSchema, core_schema


logger = logging.getLogger("queries")

maxint = 2147483647
Expand Down Expand Up @@ -177,101 +179,7 @@ def _get_type_adapter(cls) -> TypeAdapter:
raise NotImplementedError("should be overridden in metaclass")


class QueryBuilder(object):
"""A utility class to wrap multiple QueryBaseModel classes"""

def __init__(self, query: type):
"""
Args:
query: a class which inherits from one or more pydantic query
models, QueryBaseModel.
"""
self.query = query

def _bases(self) -> list[type]:
"""inspects the object and returns base classes
Removes primitive objects in ancestry to only include Pydantic Query
and Path models
Returns:
a sorted list of base classes
"""
bases = list(inspect.getmro(self.query.__class__))[
:-3
] # removes object primitives <class '__main__.QueryBaseModel'>, <class 'pydantic.main.BaseModel'>, <class 'object'>
bases_sorted = sorted(
bases, key=operator.attrgetter("__name__")
) # sort to ensure consistent order for reliability in testing
return bases_sorted

def fields(self) -> str:
"""
loops through all ancestor classes and calls
their respective fields() methods to concatenate
into additional fields for select
Returns:
"""
fields = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "fields", None)):
if base.fields(self.query):
fields.append(base.fields(self.query))
if len(fields):
fields = list(set(fields))
return "\n," + ("\n,").join(fields)
else:
return ""

def pagination(self) -> str:
pagination = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "pagination", None)):
if base.pagination(self.query):
pagination.append(base.pagination(self.query))
if len(pagination):
pagination = list(set(pagination))
return "\n" + ("\n,").join(pagination)
else:
return ""

def params(self) -> dict:
return self.query.model_dump(exclude_unset=True, by_alias=True)

@staticmethod
def total() -> str:
"""Generates the SQL for the count of total records found.
Returns:
SQL string for the count of total records found
"""
return ", COUNT(1) OVER() as found"

def where(self) -> str:
"""Introspects object ancestors and calls respective where() methods.
Returns:
SQL string of all ancestor WHERE clauses.
"""
where = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "where", None)):
if base.where(self.query):
where.append(base.where(self.query))
if len(where):
where = list(set(where))
where.sort() # ensure the order is consistent for testing
return "WHERE " + ("\nAND ").join(where)
else:
return ""


class QueryBaseModel(BaseModel):
class QueryBaseModel(ABC, BaseModel):
"""Base class for building query objects.
All query objects should inherit this model and can implement
Expand Down Expand Up @@ -325,6 +233,20 @@ def pagination(self):
...


class SortOrder(StrEnum):
ASC = "asc"
DESC = "desc"


class SortingBase(ABC, BaseModel):
order_by: str
sort_order: SortOrder | None = Query(
SortOrder.ASC,
description="Sort results ascending or descending. Default ASC",
examples=["sort=desc"],
)


# Thinking about how the paging should be done
# we should not let folks pass an offset if we also include
# a page parameter. And until pydantic supports computed
Expand Down Expand Up @@ -846,3 +768,117 @@ def where(self) -> str | None:

class MeasurementsQueries(Paging, ParametersQuery):
...


class QueryBuilder(object):
"""A utility class to wrap multiple QueryBaseModel classes"""

def __init__(self, query: type):
"""
Args:
query: a class which inherits from one or more pydantic query
models, QueryBaseModel.
"""
self.query = query
self.sort_field = False

def _bases(self) -> list[type]:
"""inspects the object and returns base classes
Removes primitive objects in ancestry to only include Pydantic Query
and Path models
Returns:
a sorted list of base classes
"""
base_classes = inspect.getmro(self.query.__class__)
bases = [
x for x in base_classes if not ABC in x.__bases__
] # remove all abstract classes
bases.remove(object) # remove <class 'object'>
bases.remove(ABC) # <class 'ABC'>
bases.remove(BaseModel) # <class 'pydantic.main.BaseModel'>
bases_sorted = sorted(
bases, key=operator.attrgetter("__name__")
) # sort to ensure consistent order for reliability in testing
return bases_sorted

@property
def _sortable(self) -> SortingBase | None:
base_classes = inspect.getmro(self.query.__class__)
sort_class = [x for x in base_classes if issubclass(x, SortingBase)]
if len(sort_class) > 0:
sort_class.remove(self.query.__class__)
sort_class.remove(SortingBase)
return sort_class[0]
else:
return None

def fields(self) -> str:
"""
loops through all ancestor classes and calls
their respective fields() methods to concatenate
into additional fields for select
Returns:
"""
fields = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "fields", None)):
if base.fields(self.query):
fields.append(base.fields(self.query))
if len(fields):
fields = list(set(fields))
return "\n," + ("\n,").join(fields)
else:
return ""

def pagination(self) -> str:
pagination = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "pagination", None)):
if base.pagination(self.query):
pagination.append(base.pagination(self.query))
if len(pagination):
pagination = list(set(pagination))
return "\n" + ("\n,").join(pagination)
else:
return ""

def params(self) -> dict:
return self.query.model_dump(exclude_unset=True, by_alias=True)

@staticmethod
def total() -> str:
"""Generates the SQL for the count of total records found.
Returns:
SQL string for the count of total records found
"""
return ", COUNT(1) OVER() as found"

def where(self) -> str:
"""Introspects object ancestors and calls respective where() methods.
Returns:
SQL string of all ancestor WHERE clauses.
"""
where = []
bases = self._bases()
for base in bases:
if callable(getattr(base, "where", None)):
if base.where(self.query):
where.append(base.where(self.query))
if len(where):
where = list(set(where))
where.sort() # ensure the order is consistent for testing
return "WHERE " + ("\nAND ").join(where)
else:
return ""

def order_by(self) -> str | None:
if self._sortable:
return f"ORDER BY {self.query.order_by.lower()} {self.query.sort_order.upper()}"
23 changes: 16 additions & 7 deletions openaq_api/openaq_api/v3/routers/countries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import StrEnum, auto
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, Path
from fastapi import APIRouter, Depends, Path, Query

from openaq_api.db import DB
from openaq_api.v3.models.queries import (
Expand All @@ -10,6 +11,7 @@
ProviderQuery,
QueryBaseModel,
QueryBuilder,
SortingBase,
)
from openaq_api.v3.models.responses import CountriesResponse

Expand Down Expand Up @@ -47,12 +49,19 @@ def where(self) -> str:
return "id = :countries_id"


## TODO
class CountriesQueries(
Paging,
ParametersQuery,
ProviderQuery,
):
class CountriesSortFields(StrEnum):
ID = auto()


class CountriesSorting(SortingBase):
order_by: CountriesSortFields | None = Query(
"id",
description="The field by which to order results",
examples=["order_by=id"],
)


class CountriesQueries(Paging, ParametersQuery, ProviderQuery, CountriesSorting):
...


Expand Down
Loading

0 comments on commit 75a2efd

Please sign in to comment.