Commit
Merge pull request #154 from lendingblock/master
1.2.0
lsbardel authored Jan 11, 2019
2 parents eaad314 + 02839a1 commit afb36a6
Showing 7 changed files with 160 additions and 71 deletions.
2 changes: 1 addition & 1 deletion openapi/__init__.py
@@ -1,4 +1,4 @@
"""Minimal OpenAPI asynchronous server application
"""

__version__ = '1.1.3'
__version__ = '1.2.0'
10 changes: 2 additions & 8 deletions openapi/cli.py
@@ -46,9 +46,7 @@ def web(self):
"""Return the web application
"""
if self._web is None:
app = web.Application(
debug=get_debug_flag()
)
app = web.Application()
app['cli'] = self
app['spec'] = self.spec
app['cwd'] = os.getcwd()
@@ -61,9 +59,7 @@ def get_serve_app(self):
def get_serve_app(self):
app = self.web()
if self.base_path:
base = web.Application(
debug=get_debug_flag()
)
base = web.Application()
base.add_subapp(self.base_path, app)
app = base
return app
@@ -106,7 +102,5 @@ def serve(ctx, host, port, reload):
"""Run the aiohttp server.
"""
app = ctx.obj['app']['cli'].get_serve_app()
if reload is None and app.debug:
reload = True
access_log = getLogger()
web.run_app(app, host=host, port=port, access_log=access_log)
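
Note: both `web.Application(debug=get_debug_flag())` constructions become plain `web.Application()` (the `debug` argument was deprecated upstream in aiohttp), and `serve` no longer turns `reload` on from `app.debug`. A minimal sketch of what the serve call boils down to after this change; host and port values are illustrative:

```python
from logging import getLogger

from aiohttp import web

# A plain Application (no debug flag) run with the root logger as access log,
# as the updated serve command does. Host/port here are illustrative; the
# real values come from the CLI options.
app = web.Application()
web.run_app(app, host='localhost', port=8080, access_log=getLogger())
```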
51 changes: 2 additions & 49 deletions openapi/db/dbmodel.py
@@ -1,35 +1,10 @@
from sqlalchemy import or_
from sqlalchemy.sql import and_, Select
from sqlalchemy.sql import and_

from ..db.container import Database
from .compile import compile_query
from ..spec.pagination import DEF_PAGINATION_LIMIT


class CrudDB(Database):

@classmethod
def get_order_clause(cls, table, query, order_by, order_desc):
if not order_by:
return query

order_by_column = getattr(table.c, order_by)
if order_desc:
order_by_column = order_by_column.desc()
return query.order_by(order_by_column)

@classmethod
def get_search_clause(cls, table, query, search, search_columns):
if not search:
return query

columns = [getattr(table.c, col) for col in search_columns]
return query.where(
or_(
*(col.ilike(f'%{search}%') for col in columns)
)
)

async def db_select(self, table, filters, *, conn=None, consumer=None):
query = self.get_query(table, table.select(), consumer, filters)
sql, args = compile_query(query)
@@ -57,12 +32,7 @@ def get_query(self, table, query, consumer=None, params=None):
filters = []
columns = table.c
params = params or {}
limit = params.pop('limit', DEF_PAGINATION_LIMIT)
offset = params.pop('offset', 0)
order_by = params.pop('order_by', None)
order_desc = params.pop('order_desc', False)
search = params.pop('search', None)
search_columns = params.pop('search_fields', [])

for key, value in params.items():
bits = key.split(':')
field = bits[0]
@@ -80,23 +50,6 @@ def get_query(self, table, query, consumer=None, params=None):
if filters:
filters = and_(*filters) if len(filters) > 1 else filters[0]
query = query.where(filters)

if isinstance(query, Select):
# ordering
query = self.get_order_clause(table, query, order_by, order_desc)

# pagination
query = query.offset(offset)
query = query.limit(limit)

# search
query = self.get_search_clause(
table,
query,
search,
search_columns
)

return query

def default_filter_field(self, field, op, value):
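
With `get_order_clause`, `get_search_clause` and the special pagination parameters removed here, `CrudDB.get_query` now only translates filter params into `WHERE` clauses; ordering, search and pagination move to the path layer (next file). A hedged sketch of the narrowed role, where `build_filtered_select` and the `table`/`params` values are hypothetical:

```python
from openapi.db.compile import compile_query
from openapi.db.dbmodel import CrudDB

def build_filtered_select(db: CrudDB, table, params: dict):
    # get_query now only applies consumer/filter clauses; no order_by,
    # search or limit/offset handling happens at this level any more.
    query = db.get_query(table, table.select(), None, dict(params))
    return compile_query(query)  # -> (sql, args) ready for asyncpg
```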
58 changes: 57 additions & 1 deletion openapi/db/path.py
@@ -2,10 +2,12 @@

from aiohttp import web
from asyncpg.exceptions import UniqueViolationError
from sqlalchemy.sql import or_

from .compile import compile_query
from ..db.dbmodel import CrudDB
from ..spec.path import ApiPath
from ..spec.pagination import Pagination, DEF_PAGINATION_LIMIT

unique_regex = re.compile(r'Key \((?P<column>(\w+,? ?)+)\)=\((?P<value>.+)\)')

@@ -27,6 +29,27 @@ def db(self) -> CrudDB:
def db_table(self):
return self.db.metadata.tables[self.table]

def get_search_clause(self, table, query, search, search_columns):
if not search:
return query

columns = [getattr(table.c, col) for col in search_columns]
return query.where(
or_(
*(col.ilike(f'%{search}%') for col in columns)
)
)

def get_special_params(self, params):
return dict(
limit=params.pop('limit', DEF_PAGINATION_LIMIT),
offset=params.pop('offset', 0),
order_by=params.pop('order_by', None),
order_desc=params.pop('order_desc', False),
search=params.pop('search', None),
search_columns=params.pop('search_fields', []),
)

async def get_list(
self, *, filters=None, query=None, table=None,
query_schema='query_schema', dump_schema='response_schema',
@@ -37,12 +60,45 @@ async def get_list(
table = table if table is not None else self.db_table
if not filters:
filters = self.get_filters(query=query, query_schema=query_schema)
specials = self.get_special_params(filters)
query = self.db.get_query(table, table.select(), self, filters)
#
query_count = query.alias('inner').count()
#
# order by
if specials['order_by']:
order_by_column = getattr(table.c, specials['order_by'], None)
if order_by_column is not None:
if specials['order_desc']:
order_by_column = order_by_column.desc()
query = query.order_by(order_by_column)

# search
query = self.get_search_clause(
table,
query,
specials['search'],
specials['search_columns']
)

# pagination
offset = specials['offset']
limit = specials['limit']
if offset:
query = query.offset(offset)
if limit:
query = query.limit(limit)

sql, args = compile_query(query)
sql_count, args_count = compile_query(query_count)
async with self.db.ensure_connection(conn) as conn:
total = await conn.fetchrow(sql_count, *args_count)
values = await conn.fetch(sql, *args)
return self.dump(dump_schema, values)
pagination = Pagination(self.request.url)
data = self.dump(dump_schema, values)
return pagination.paginated(
data, total['tbl_row_count'], offset, limit
)

async def create_one(
self, *, data=None, table=None, body_schema='body_schema',
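
`get_list` now applies ordering, search and limit/offset itself, issues a companion count query, and returns a `PaginatedData` object instead of a plain list. A hedged usage sketch; `SqlApiPath` is assumed to be the path class defined in this module (its name is not visible in these hunks) and the `TasksPath`/`tasks` names are illustrative:

```python
from openapi.db.path import SqlApiPath  # assumed export of this module

class TasksPath(SqlApiPath):  # illustrative endpoint
    table = 'tasks'

    async def get(self):
        paginated = await self.get_list()   # PaginatedData, not a bare list
        return paginated.json_response()    # body plus Link header when paging applies
```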
46 changes: 39 additions & 7 deletions openapi/spec/pagination.py
@@ -2,20 +2,52 @@

from multidict import MultiDict

from aiohttp import web

from openapi.json import dumps


MAX_PAGINATION_LIMIT = int(os.environ.get('MAX_PAGINATION_LIMIT') or 100)
DEF_PAGINATION_LIMIT = int(os.environ.get('DEF_PAGINATION_LIMIT') or 50)


class PaginatedData:

def __init__(self, data, pagination, total, offset, limit):
self.data = data
self.pagination = pagination
self.total = total
self.offset = offset
self.limit = limit

def json_response(self, headers=None, **kwargs):
headers = headers or {}
links = self.header_links()
if links:
headers['Link'] = links
return web.json_response(
self.data, headers=headers, **kwargs, dumps=dumps
)

def header_links(self):
links = self.pagination.links(self.total, self.limit, self.offset)
return ', '.join(
f'<{value}> rel="{name}"' for name, value in links.items()
)


class Pagination:
def __init__(self, url):
self.url = url
self.query = MultiDict(url.query)

def paginated(self, data, total, offset, limit):
return PaginatedData(data, self, total, offset, limit)

def first_link(self, total, limit, offset):
n = self._count_part(offset, limit, 0)
if n:
offset -= n*limit
offset -= n * limit
if offset > 0:
return self.link(0, min(limit, offset))

@@ -33,7 +65,7 @@ def next_link(self, total, limit, offset):
def last_link(self, total, limit, offset):
n = self._count_part(total, limit, offset)
if n > 0:
return self.link(offset + n*limit, limit)
return self.link(offset + n * limit, limit)

def link(self, offset, limit):
query = self.query.copy()
@@ -43,7 +75,7 @@ def link(self, offset, limit):
def _count_part(self, total, limit, offset):
n = (total - offset) // limit
# make sure we account for perfect matching
if n*limit + offset == total:
if n * limit + offset == total:
n -= 1
return max(0, n)

@@ -55,10 +87,10 @@ def links(self, total, limit, offset):
prev = self.prev_link(total, limit, offset)
if prev != first:
links['prev'] = prev
next = self.next_link(total, limit, offset)
if next:
next_ = self.next_link(total, limit, offset)
if next_:
last = self.last_link(total, limit, offset)
if last != next:
links['next'] = next
if last != next_:
links['next'] = next_
links['last'] = last
return links
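
The new `PaginatedData` wrapper ties a result set to its `Pagination` helper and renders the first/prev/next/last links as an HTTP `Link` header value. A standalone sketch using only what this diff defines; the URL and counts are illustrative:

```python
from yarl import URL

from openapi.spec.pagination import Pagination

# Build the Link header value by hand for a request at offset 50 of 100 rows.
pagination = Pagination(URL('https://example.com/tasks?limit=25&offset=50'))
paginated = pagination.paginated(data=[], total=100, offset=50, limit=25)
print(paginated.header_links())
# e.g. <https://example.com/tasks?limit=25&offset=0> rel="first", <...&offset=25> rel="prev", ...
```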
9 changes: 4 additions & 5 deletions tests/example/endpoints.py
@@ -54,8 +54,8 @@ async def get(self):
200:
description: Authenticated tasks
"""
data = await self.get_list()
return self.json_response(data)
paginated = await self.get_list()
return paginated.json_response()

@op(response_schema=Task, body_schema=TaskAdd)
async def post(self):
@@ -209,9 +209,8 @@ async def get(self):
200:
description: Authenticated tasks
"""
async with self.db.transaction() as conn:
data = await self.get_list(conn=conn)
return self.json_response(data=data)
paginated = await self.get_list()
return paginated.json_response()


@routes.view('/transaction/tasks/{id}')
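
From the client side, these example endpoints now return `/tasks` listings with first/prev/next/last pagination links in the `Link` response header. A hedged illustration with the aiohttp client; the base URL and parameter values are assumptions:

```python
import asyncio

import aiohttp

async def fetch_tasks_page(base_url: str, limit: int = 10, offset: int = 0):
    # Request one page of tasks and pick up the pagination links, if any.
    async with aiohttp.ClientSession() as session:
        async with session.get(
            f'{base_url}/tasks', params={'limit': limit, 'offset': offset}
        ) as resp:
            tasks = await resp.json()
            links = resp.headers.get('Link')  # first/prev/next/last, when present
            return tasks, links

# asyncio.run(fetch_tasks_page('http://localhost:8080'))
```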
55 changes: 55 additions & 0 deletions tests/test_pagination.py
@@ -1,4 +1,5 @@
from yarl import URL
from openapi.testing import jsonBody
from openapi.spec.pagination import Pagination


@@ -29,3 +30,57 @@ def test_last_link():
assert links['prev'].query['offset'] == '25'
assert links['next'].query['offset'] == '75'
assert links['last'].query['offset'] == '100'


async def test_pagination_next_link(cli):
response = await cli.post('/tasks', json=dict(title='bla'))
await jsonBody(response, 201)
response = await cli.post('/tasks', json=dict(title='foo'))
await jsonBody(response, 201)
response = await cli.get('/tasks')
data = await jsonBody(response)
assert 'Link' not in response.headers
assert len(data) == 2


async def test_pagination_first_link(cli):
response = await cli.post('/tasks', json=dict(title='bla'))
await jsonBody(response, 201)
response = await cli.post('/tasks', json=dict(title='foo'))
await jsonBody(response, 201)
response = await cli.get(
'/tasks',
params={'limit': 10, 'offset': 20}
)
url = response.url
data = await jsonBody(response)
link = response.headers['Link']
assert link == (
f'<{url.parent}{url.path}?limit=10&offset=0> rel="first", '
f'<{url.parent}{url.path}?limit=10&offset=10> rel="prev"'
)
assert 'Link' in response.headers
assert len(data) == 0


async def test_invalid_limit_offset(cli):
response = await cli.get(
'/tasks',
params={'limit': 'wtf'}
)
await jsonBody(response, 422)
response = await cli.get(
'/tasks',
params={'limit': 0}
)
await jsonBody(response, 422)
response = await cli.get(
'/tasks',
params={'offset': 'wtf'}
)
await jsonBody(response, 422)
response = await cli.get(
'/tasks',
params={'offset': -10}
)
await jsonBody(response, 422)
