Skip to content
This repository has been archived by the owner on Mar 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #47 from lendingblock/master
Browse files Browse the repository at this point in the history
0.3.1
  • Loading branch information
lsbardel authored Jul 14, 2018
2 parents 3ca187c + 2c3232b commit 816d365
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 69 deletions.
2 changes: 1 addition & 1 deletion openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Minimal OpenAPI asynchronous server application
"""
__version__ = '0.3.0'
__version__ = '0.3.1'
147 changes: 86 additions & 61 deletions openapi/spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import List
from typing import List, Dict

from aiohttp import hdrs
from aiohttp import web
from dataclasses import dataclass, asdict, is_dataclass
from dataclasses import dataclass, asdict, is_dataclass, field

from .exceptions import InvalidTypeException
from .path import ApiPath
Expand All @@ -16,6 +16,7 @@

OPENAPI = '3.0.1'
METHODS = [method.lower() for method in hdrs.METH_ALL]
SCHEMAS_TO_SCHEMA = ('response_schema', 'body_schema')
SCHEMA_BASE_REF = '#/components/schemas/'


Expand All @@ -38,6 +39,7 @@ class OpenApi:
description: str = ''
version: str = '0.1.0'
termsOfService: str = ''
security: Dict[str, Dict] = field(default_factory=dict)
contact: Contact = Contact()
license: License = License()

Expand All @@ -53,19 +55,23 @@ class SchemaParser:
Decimal: {'type': 'number'}
}

parsed_schemas = {}

def __init__(self, schemas_to_parse):
self._schemas_to_parse = set(schemas_to_parse)

def parse(self):
for schema in self._schemas_to_parse:
if schema.__name__ in self.parsed_schemas:
continue

parsed_schema = self._schema2json(schema)
self.parsed_schemas[schema.__name__] = parsed_schema
return self.parsed_schemas
def __init__(self, group=None):
self.group = group or SchemaGroup()

def parameters(self, Schema, default_in='path'):
params = []
schema = self.schema2json(Schema)
required = set(schema['required'])
for name, entry in schema['properties'].items():
entry = compact(
name=name,
description=entry.pop('description', None),
schema=entry,
required=name in required
)
entry['in'] = default_in
params.append(entry)
return params

def field2json(self, field):
field = fields.as_field(field)
Expand All @@ -78,7 +84,7 @@ def field2json(self, field):
elif is_subclass(field.type, List):
return self._list2json(field.type)
elif is_dataclass(field.type):
return self._get_schema_ref(field.type)
return self.get_schema_ref(field.type)
else:
raise InvalidTypeException(field.type)

Expand All @@ -97,7 +103,7 @@ def field2json(self, field):
validator.openapi(json_property)
return json_property

def _schema2json(self, schema):
def schema2json(self, schema):
properties = {}
required = []
for item in schema.__dataclass_fields__.values():
Expand All @@ -117,10 +123,10 @@ def _schema2json(self, schema):
'additionalProperties': False
}

def _get_schema_ref(self, schema):
if schema not in self.parsed_schemas:
parsed_schema = self._schema2json(schema)
self.parsed_schemas[schema.__name__] = parsed_schema
def get_schema_ref(self, schema):
if schema not in self.group.parsed_schemas:
parsed_schema = self.schema2json(schema)
self.group.parsed_schemas[schema.__name__] = parsed_schema

return {'$ref': SCHEMA_BASE_REF + schema.__name__}

Expand All @@ -131,6 +137,21 @@ def _list2json(self, field_type):
}


class SchemaGroup:

def __init__(self):
self.parsed_schemas = {}

def parse(self, schemas):
for schema in set(schemas):
if schema.__name__ in self.parsed_schemas:
continue

parsed_schema = SchemaParser(self).schema2json(schema)
self.parsed_schemas[schema.__name__] = parsed_schema
return self.parsed_schemas


class OpenApiSpec:
"""Open API document builder
"""
Expand Down Expand Up @@ -160,9 +181,13 @@ def build(self, app):
"""
self.logger = app.logger
self.schemas_to_parse.add(app['exc_schema'])
security = self.doc['info'].get('security')
sk = {}
if security:
sk = security
self.doc['info']['security'] = list(sk)
self._build_paths(app)
schemas_parser = SchemaParser(self.schemas_to_parse)
self.schemas = schemas_parser.parse()
self.schemas = SchemaGroup().parse(self.schemas_to_parse)
s = self.schemas
p = self.parameters
r = self.responses
Expand All @@ -173,6 +198,7 @@ def build(self, app):
schemas=OrderedDict(((k, s[k]) for k in sorted(s))),
parameters=OrderedDict(((k, p[k]) for k in sorted(p))),
responses=OrderedDict(((k, r[k]) for k in sorted(r))),
securitySchemes=OrderedDict((((k, sk[k]) for k in sorted(sk))))
),
servers=self.servers
))
Expand All @@ -194,6 +220,9 @@ def _build_paths(self, app):
def _build_path_object(self, handler, path_obj):
path_obj = load_yaml_from_docstring(handler.__doc__) or {}
tags = self._extend_tags(path_obj.pop('tags', None))
if handler.path_schema:
p = SchemaParser()
path_obj['parameters'] = p.parameters(handler.path_schema)
for method in METHODS:
method_handler = getattr(handler, method, None)
if method_handler is None:
Expand All @@ -207,40 +236,38 @@ def _build_path_object(self, handler, path_obj):
continue

method_doc = load_yaml_from_docstring(method_handler.__doc__) or {}
if method_doc.pop('private', False):
continue
mtags = tags.copy()
mtags.update(self._extend_tags(method_doc.pop('tags', None)))
op_attrs = asdict(operation)
self._add_schemas_from_operation(op_attrs)
responses = self._get_resonse_object(op_attrs, method_doc)
request_body = self._get_request_body_object(op_attrs, method_doc)

self._get_response_object(op_attrs, method_doc)
self._get_request_body_object(op_attrs, method_doc)
self._get_query_parameters(op_attrs, method_doc)
method_doc['tags'] = list(mtags)
path_obj[method] = method_doc

if responses is not None:
path_obj[method]['responses'] = responses

if request_body is not None:
path_obj[method]['requestBody'] = request_body

return path_obj

def _get_resonse_object(self, op_attrs, method_doc):
def _get_schema_info(self, schema):
info = {}
if type(schema) == list:
info['type'] = 'array'
info['items'] = {
'$ref': f'{SCHEMA_BASE_REF}{schema[0].__name__}'
}
elif schema is not None:
info['$ref'] = f'{SCHEMA_BASE_REF}{schema.__name__}'
return info

def _get_response_object(self, op_attrs, doc):
response_schema = op_attrs.get('response_schema', None)
if response_schema is None:
return None

schema = {}
if type(response_schema) == list:
schema['type'] = 'array'
schema['items'] = {
'$ref': SCHEMA_BASE_REF + response_schema[0].__name__
}
elif response_schema is not None:
schema['$ref'] = SCHEMA_BASE_REF + response_schema.__name__

schema = self._get_schema_info(response_schema)
responses = {}
for response, data in method_doc.get('responses', {}).items():
for response, data in doc.get('responses', {}).items():
responses[response] = {
'description': data['description'],
'content': {
Expand All @@ -249,28 +276,26 @@ def _get_resonse_object(self, op_attrs, method_doc):
}
}
}
return responses

def _get_request_body_object(self, op_attrs, method_doc):
body_schema = op_attrs.get('body_schema', None)
if body_schema is None:
return
doc['responses'] = responses

if type(body_schema) == list:
body_schema = body_schema[0]

return {
'description': method_doc.get('body', {}).get('summary', ''),
'content': {
'application/json': {
'schema': SCHEMA_BASE_REF + body_schema.__name__
def _get_request_body_object(self, op_attrs, doc):
schema = self._get_schema_info(op_attrs.get('body_schema', None))
if schema:
doc['requestBody'] = {
'content': {
'application/json': {
'schema': schema
}
}
}
}

def _get_query_parameters(self, op_attrs, doc):
schema = op_attrs.get('query_schema', None)
if schema:
doc['parameters'] = SchemaParser().parameters(schema, 'query')

def _add_schemas_from_operation(self, operation_obj):
schemas = ['response_schema', 'body_schema', 'query_schema']
for schema in schemas:
for schema in SCHEMAS_TO_SCHEMA:
schema_obj = operation_obj[schema]
if schema_obj is not None:
if type(schema_obj) == list:
Expand Down
2 changes: 1 addition & 1 deletion openapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NO_DEBUG = {'0', 'false', 'no'}


class _AsyncGeneratorContextManager:
class _AsyncGeneratorContextManager: # pragma: no cover
def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
self.func, self.args, self.kwds = func, args, kwds
Expand Down
10 changes: 5 additions & 5 deletions tests/spec/test_schema_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_get_schema_ref():
class MyClass:
str_field: str

parser = SchemaParser([MyClass])
parser = SchemaParser()

schema_ref = parser._get_schema_ref(MyClass)
schema_ref = parser.get_schema_ref(MyClass)
assert schema_ref == {'$ref': '#/components/schemas/MyClass'}
assert 'MyClass' in parser.parsed_schemas.keys()
assert 'MyClass' in parser.group.parsed_schemas.keys()


def test_schema2json():
Expand All @@ -40,8 +40,8 @@ class MyClass:
ref_field: OtherClass = field(metadata={'required': True})
list_ref_field: List[OtherClass]

parser = SchemaParser([])
schema_json = parser._schema2json(MyClass)
parser = SchemaParser()
schema_json = parser.schema2json(MyClass)
expected = {
'type': 'object',
'description': 'Test data',
Expand Down
17 changes: 17 additions & 0 deletions tests/spec/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,20 @@ async def test_spec_validation(test_app):
spec = OpenApiSpec(asdict(open_api))
spec.build(test_app)
# validate_spec(spec.doc)


async def test_spec_security(test_app):
open_api = OpenApi(
security=dict(
auth_key={
'type': 'apiKey',
'name': 'X-Api-Key',
'description': 'The authentication key',
'in': 'header'
}
)
)
spec = OpenApiSpec(asdict(open_api))
spec.build(test_app)
assert spec.doc['info']['security'] == ['auth_key']
assert spec.doc['components']['securitySchemes']
4 changes: 3 additions & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ async def test_spec(test_app):

spec = OpenApiSpec(asdict(open_api))
spec.build(test_app)
assert spec.schemas['TaskQuery']['properties'].keys() == {
query = spec.paths['/tasks']['get']['parameters']
filters = [q['name'] for q in query]
assert set(filters) == {
'done',
'severity',
'severity:lt',
Expand Down

0 comments on commit 816d365

Please sign in to comment.