diff --git a/openapi/__init__.py b/openapi/__init__.py index d6aa6bc..5726dbf 100644 --- a/openapi/__init__.py +++ b/openapi/__init__.py @@ -1,4 +1,4 @@ """Minimal OpenAPI asynchronous server application """ -__version__ = '0.7.7' +__version__ = '0.8.0' diff --git a/openapi/rest.py b/openapi/rest.py index 0038878..a0cba20 100644 --- a/openapi/rest.py +++ b/openapi/rest.py @@ -1,19 +1,33 @@ from dataclasses import dataclass +import typing from .data.fields import Choice, IntegerValidator from .cli import OpenApiClient from .data.fields import data_field, bool_field -from .spec import OpenApi +from .spec import OpenApi, OpenApiSpec from .spec.utils import docjoin from .spec.pagination import MAX_PAGINATION_LIMIT -def rest(setup_app=None, base_path=None, commands=None, **kwargs): +def rest( + openapi: dict=None, + setup_app: object=None, + base_path: str=None, + commands: typing.List=None, + allowed_tags: typing.Set=None, + validate_docs: bool=False +): """Create the OpenApi application server """ - spec = OpenApi(**kwargs) return OpenApiClient( - spec, base_path=base_path, commands=commands, setup_app=setup_app + OpenApiSpec( + OpenApi(**(openapi or {})), + allowed_tags=allowed_tags, + validate_docs=validate_docs + ), + base_path=base_path, + commands=commands, + setup_app=setup_app ) diff --git a/openapi/spec/spec.py b/openapi/spec/spec.py index 17a8388..a88a6d2 100644 --- a/openapi/spec/spec.py +++ b/openapi/spec/spec.py @@ -2,11 +2,11 @@ from datetime import datetime, date from decimal import Decimal from enum import Enum -from typing import List, Dict +from typing import List, Dict, Iterable +from dataclasses import dataclass, asdict, is_dataclass, field from aiohttp import hdrs from aiohttp import web -from dataclasses import dataclass, asdict, is_dataclass, field from .exceptions import InvalidTypeException, InvalidSpecException from .path import ApiPath @@ -59,8 +59,9 @@ class SchemaParser: Decimal: {'type': 'number'} } - def __init__(self, group=None): + def __init__(self, group=None, validate_docs=False): self.group = group or SchemaGroup() + self.validate_docs = validate_docs def parameters(self, Schema, default_in='path'): params = [] @@ -77,7 +78,7 @@ def parameters(self, Schema, default_in='path'): params.append(entry) return params - def field2json(self, field, validate_info=True): + def field2json(self, field, validate_docs=True): field = fields.as_field(field) mapping = self._fields_mapping.get(field.type, None) if not mapping: @@ -100,9 +101,9 @@ def field2json(self, field, validate_info=True): meta = field.metadata field_description = meta.get(fields.DESCRIPTION) if not field_description: - if validate_info: + if validate_docs and self.validate_docs: raise InvalidSpecException( - f'Missing description for field {field.name}' + f'Missing description for field "{field.name}"' ) else: json_property['description'] = field_description @@ -167,12 +168,14 @@ class SchemaGroup: def __init__(self): self.parsed_schemas = {} - def parse(self, schemas): + def parse(self, schemas, validate_docs=False): for schema in set(schemas): if schema.__name__ in self.parsed_schemas: continue - parsed_schema = SchemaParser(self).schema2json(schema) + parsed_schema = SchemaParser( + self, validate_docs=validate_docs + ).schema2json(schema) self.parsed_schemas[schema.__name__] = parsed_schema return self.parsed_schemas @@ -180,8 +183,13 @@ def parse(self, schemas): class OpenApiSpec: """Open API document builder """ - def __init__(self, info, default_content_type=None, - default_responses=None, allowed_tags=None): + def __init__( + self, + info: OpenApi=None, + default_content_type: str=None, + default_responses: Iterable=None, + allowed_tags: Iterable=None, + validate_docs: bool=False): self.schemas = {} self.parameters = {} self.responses = {} @@ -192,16 +200,25 @@ def __init__(self, info, default_content_type=None, self.default_responses = default_responses or {} self.doc = dict( openapi=OPENAPI, - info=info, + info=asdict(info or OpenApi()), paths=OrderedDict() ) self.schemas_to_parse = set() self.allowed_tags = allowed_tags + self.validate_docs = validate_docs @property def paths(self): return self.doc['paths'] + @property + def title(self): + return self.doc['info']['title'] + + @property + def version(self): + return self.doc['info']['version'] + def build(self, app, public=True, private=False): """Build the ``doc`` dictionary by adding paths """ @@ -230,7 +247,7 @@ def build(self, app, public=True, private=False): ), servers=self.servers )) - return self + return doc def _build_paths(self, app, public, private): """Loop through app paths and add @@ -244,30 +261,37 @@ def _build_paths(self, app, public, private): handler = route.handler if (issubclass(handler, ApiPath) and self._include(handler.private, public, private)): - paths[path] = self._build_path_object( - handler, app, public, private - ) - - self._validate_tags() + try: + paths[path] = self._build_path_object( + handler, app, public, private + ) + except InvalidSpecException as exc: + raise InvalidSpecException( + f'Invalid spec in route "{path}": {exc}' + ) from None + + if self.validate_docs: + self._validate_tags() def _validate_tags(self): for tag_name, tag_obj in self.tags.items(): if self.allowed_tags and tag_name not in self.allowed_tags: - raise InvalidSpecException(f'Tag {tag_name} not allowed') + raise InvalidSpecException(f'Tag "{tag_name}" not allowed') if 'description' not in tag_obj: raise InvalidSpecException( - f'Missing tag {tag_name} description' + f'Missing tag "{tag_name}" description' ) def _build_path_object(self, handler, path_obj, public, private): path_obj = load_yaml_from_docstring(handler.__doc__) or {} doc_tags = path_obj.pop('tags', None) - if not doc_tags: - raise InvalidSpecException(f'Missing tags docstring for {handler}') + if not doc_tags and self.validate_docs: + raise InvalidSpecException( + f'Missing tags docstring for "{handler}"') tags = self._extend_tags(doc_tags) if handler.path_schema: - p = SchemaParser() + p = SchemaParser(validate_docs=self.validate_docs) path_obj['parameters'] = p.parameters(handler.path_schema) for method in METHODS: method_handler = getattr(handler, method, None) @@ -311,16 +335,17 @@ def _get_schema_info(self, schema): return info def _get_method_info(self, method_handler, method_doc): - summary = method_doc.get('summary') - if not summary: - raise InvalidSpecException( - f'Missing method summary for {method_handler}' - ) - description = method_doc.get('description') - if not description: - raise InvalidSpecException( - f'Missing method description for {method_handler}' - ) + summary = method_doc.get('summary', '') + description = method_doc.get('description', '') + if self.validate_docs: + if not summary: + raise InvalidSpecException( + f'Missing method summary for "{method_handler}"' + ) + if not description: + raise InvalidSpecException( + f'Missing method description for "{method_handler}"' + ) return {'summary': summary, 'description': description} def _get_response_object(self, op_attrs, doc): @@ -359,7 +384,8 @@ def _get_request_body_object(self, op_attrs, doc): def _get_query_parameters(self, op_attrs, doc): schema = op_attrs.get('query_schema', None) if schema: - doc['parameters'] = SchemaParser().parameters(schema, 'query') + doc['parameters'] = SchemaParser( + validate_docs=self.validate_docs).parameters(schema, 'query') def _add_schemas_from_operation(self, operation_obj): for schema in SCHEMAS_TO_SCHEMA: @@ -396,6 +422,5 @@ async def spec_root(request): app = request.app spec = app.get('spec_doc') if not spec: - spec = OpenApiSpec(asdict(app['spec'])) - app['spec_doc'] = spec.build(app) - return web.json_response(spec.doc) + app['spec_doc'] = app['spec'].build(app) + return web.json_response(app['spec_doc']) diff --git a/tests/spec/test_schema_parser.py b/tests/spec/test_schema_parser.py index bd0b4fc..12abdeb 100644 --- a/tests/spec/test_schema_parser.py +++ b/tests/spec/test_schema_parser.py @@ -108,11 +108,11 @@ class MyClass: def test_field2json(): parser = SchemaParser([]) - str_json = parser.field2json(str, validate_info=False) - int_json = parser.field2json(int, validate_info=False) - float_json = parser.field2json(float, validate_info=False) - bool_json = parser.field2json(bool, validate_info=False) - datetime_json = parser.field2json(datetime, validate_info=False) + str_json = parser.field2json(str) + int_json = parser.field2json(int) + float_json = parser.field2json(float) + bool_json = parser.field2json(bool) + datetime_json = parser.field2json(datetime) assert str_json == {'type': 'string'} assert int_json == {'type': 'integer', 'format': 'int32'} @@ -123,12 +123,8 @@ def test_field2json(): def test_field2json_format(): parser = SchemaParser([]) - str_json = parser.field2json( - as_field(str, format='uuid'), validate_info=False - ) - int_json = parser.field2json( - as_field(int, format='int64'), validate_info=False - ) + str_json = parser.field2json(as_field(str, format='uuid')) + int_json = parser.field2json(as_field(int, format='int64')) assert str_json == {'type': 'string', 'format': 'uuid'} assert int_json == {'type': 'integer', 'format': 'int64'} @@ -149,7 +145,7 @@ class MyClass: desc_field: str = data_field(description='Valid field') no_desc_field: str = data_field() - parser = SchemaParser() + parser = SchemaParser(validate_docs=True) with pytest.raises(InvalidSpecException): parser.get_schema_ref(MyClass) @@ -161,7 +157,7 @@ class MyEnum(Enum): FIELD_3 = 2 parser = SchemaParser([]) - json_type = parser.field2json(MyEnum, validate_info=False) + json_type = parser.field2json(MyEnum) assert json_type == { 'type': 'string', 'enum': ['FIELD_1', 'FIELD_2', 'FIELD_3'] } diff --git a/tests/spec/test_spec.py b/tests/spec/test_spec.py index 66284e2..101c5ab 100644 --- a/tests/spec/test_spec.py +++ b/tests/spec/test_spec.py @@ -1,11 +1,11 @@ import pytest -from dataclasses import asdict + from openapi.rest import rest +from openapi.spec import OpenApi, OpenApiSpec +from openapi.spec.exceptions import InvalidSpecException # from openapi_spec_validator import validate_spec -from openapi.spec import OpenApi, OpenApiSpec -from openapi.spec.exceptions import InvalidSpecException from ..example import endpoints @@ -24,9 +24,7 @@ def test_init(): async def test_spec_validation(test_app): - open_api = OpenApi() - - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec() spec.build(test_app) # validate_spec(spec.doc) @@ -42,15 +40,14 @@ async def test_spec_security(test_app): } ) ) - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec(open_api) spec.build(test_app) assert spec.doc['info']['security'] == ['auth_key'] assert spec.doc['components']['securitySchemes'] async def test_spec_422(test_app): - open_api = OpenApi() - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec() spec.build(test_app) tasks = spec.doc['paths']['/tasks'] resp = tasks['post']['responses'] @@ -62,8 +59,7 @@ async def test_spec_422(test_app): async def test_invalid_path(): app = create_spec_app(endpoints.invalid_path_routes) - open_api = OpenApi() - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec(validate_docs=True) with pytest.raises(InvalidSpecException): spec.build(app) @@ -71,8 +67,7 @@ async def test_invalid_path(): async def test_invalid_method_missing_summary(): app = create_spec_app(endpoints.invalid_method_summary_routes) - open_api = OpenApi() - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec(validate_docs=True) with pytest.raises(InvalidSpecException): spec.build(app) @@ -80,8 +75,7 @@ async def test_invalid_method_missing_summary(): async def test_invalid_method_missing_description(): app = create_spec_app(endpoints.invalid_method_description_routes) - open_api = OpenApi() - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec(validate_docs=True) with pytest.raises(InvalidSpecException): spec.build(app) @@ -89,9 +83,7 @@ async def test_invalid_method_missing_description(): async def test_allowed_tags_ok(): app = create_spec_app(endpoints.routes) - open_api = OpenApi() spec = OpenApiSpec( - asdict(open_api), allowed_tags=set(('Task', 'Transaction', 'Random')) ) spec.build(app) @@ -99,9 +91,8 @@ async def test_allowed_tags_ok(): async def test_allowed_tags_invalid(): app = create_spec_app(endpoints.routes) - open_api = OpenApi() spec = OpenApiSpec( - asdict(open_api), + validate_docs=True, allowed_tags=set(('Task', 'Transaction')) ) with pytest.raises(InvalidSpecException): @@ -110,9 +101,8 @@ async def test_allowed_tags_invalid(): async def test_tags_missing_description(): app = create_spec_app(endpoints.invalid_tag_missing_description_routes) - open_api = OpenApi() spec = OpenApiSpec( - asdict(open_api), + validate_docs=True, allowed_tags=set(('Task', 'Transaction', 'Random')) ) with pytest.raises(InvalidSpecException): diff --git a/tests/test_filters.py b/tests/test_filters.py index 1d83755..70b0f90 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,10 +1,8 @@ -from dataclasses import asdict - import pytest from multidict import MultiDict -from openapi.spec import OpenApi, OpenApiSpec +from openapi.spec import OpenApiSpec from openapi.testing import jsonBody @@ -43,9 +41,7 @@ async def assert_query(cli, params, expected): async def test_spec(test_app): - open_api = OpenApi() - - spec = OpenApiSpec(asdict(open_api)) + spec = OpenApiSpec() spec.build(test_app) query = spec.paths['/tasks']['get']['parameters'] filters = [q['name'] for q in query]