diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index 22a786a4f..e78c348c5 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -991,6 +991,20 @@ } } }, + "/responses/reference": { + "get": { + "tags": [ + "responses" + ], + "summary": "Endpoint using predefined response", + "operationId": "reference_response", + "responses": { + "200": { + "$ref": "#/components/responses/AResponse" + } + } + } + }, "/auth/token_with_cookie": { "get": { "tags": [ @@ -2971,6 +2985,18 @@ } } } + }, + "responses": { + "AResponse": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AModel" + } + } + } + } } } } diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index a19e46ce3..919dda986 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -983,6 +983,20 @@ info: } } }, + "/responses/reference": { + "get": { + "tags": [ + "responses" + ], + "summary": "Endpoint using predefined response", + "operationId": "reference_response", + "responses": { + "200": { + "$ref": "#/components/responses/AResponse" + } + } + } + }, "/auth/token_with_cookie": { "get": { "tags": [ @@ -2962,3 +2976,10 @@ info: "application/json": "schema": "$ref": "#/components/schemas/AModel" + responses: + AResponse: + description: OK + content: + "application/json": + "schema": + "$ref": "#/components/schemas/AModel" diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py index 6000bd0e7..e09dee3e3 100644 --- a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/responses/__init__.py @@ -2,7 +2,7 @@ import types -from . import post_responses_unions_simple_before_complex, text_response +from . import post_responses_unions_simple_before_complex, reference_response, text_response class ResponsesEndpoints: @@ -19,3 +19,10 @@ def text_response(cls) -> types.ModuleType: Text Response """ return text_response + + @classmethod + def reference_response(cls) -> types.ModuleType: + """ + Endpoint using predefined response + """ + return reference_response diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/responses/reference_response.py b/end_to_end_tests/golden-record/my_test_api_client/api/responses/reference_response.py new file mode 100644 index 000000000..ac71e9e50 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/api/responses/reference_response.py @@ -0,0 +1,122 @@ +from http import HTTPStatus +from typing import Any, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.a_model import AModel +from ...types import Response + + +def _get_kwargs() -> dict[str, Any]: + _kwargs: dict[str, Any] = { + "method": "get", + "url": "/responses/reference", + } + + return _kwargs + + +def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[AModel]: + if response.status_code == 200: + response_200 = AModel.from_dict(response.json()) + + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[AModel]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[AModel]: + """Endpoint using predefined response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[AModel] + """ + + kwargs = _get_kwargs() + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[AModel]: + """Endpoint using predefined response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + AModel + """ + + return sync_detailed( + client=client, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[AModel]: + """Endpoint using predefined response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[AModel] + """ + + kwargs = _get_kwargs() + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[AModel]: + """Endpoint using predefined response + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + AModel + """ + + return ( + await asyncio_detailed( + client=client, + ) + ).parsed diff --git a/openapi_python_client/parser/bodies.py b/openapi_python_client/parser/bodies.py index c51966412..7d0b12954 100644 --- a/openapi_python_client/parser/bodies.py +++ b/openapi_python_client/parser/bodies.py @@ -9,6 +9,7 @@ Schemas, property_from_data, ) +from openapi_python_client.parser.properties.schemas import get_reference_simple_name from .. import schema as oai from ..config import Config @@ -138,7 +139,7 @@ def _resolve_reference( references_seen = [] while isinstance(body, oai.Reference) and body.ref not in references_seen: references_seen.append(body.ref) - body = request_bodies.get(body.ref.split("/")[-1]) + body = request_bodies.get(get_reference_simple_name(body.ref)) if isinstance(body, oai.Reference): return ParseError(detail="Circular $ref in request body", data=body) if body is None and references_seen: diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 43e63c434..f0210e0fa 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -51,6 +51,7 @@ def from_data( schemas: Schemas, parameters: Parameters, request_bodies: dict[str, Union[oai.RequestBody, oai.Reference]], + responses: dict[str, Union[oai.Response, oai.Reference]], config: Config, ) -> tuple[dict[utils.PythonIdentifier, "EndpointCollection"], Schemas, Parameters]: """Parse the openapi paths data to get EndpointCollections by tag""" @@ -73,6 +74,7 @@ def from_data( schemas=schemas, parameters=parameters, request_bodies=request_bodies, + responses=responses, config=config, ) # Add `PathItem` parameters @@ -145,7 +147,12 @@ class Endpoint: @staticmethod def _add_responses( - *, endpoint: "Endpoint", data: oai.Responses, schemas: Schemas, config: Config + *, + endpoint: "Endpoint", + data: oai.Responses, + schemas: Schemas, + responses: dict[str, Union[oai.Response, oai.Reference]], + config: Config, ) -> tuple["Endpoint", Schemas]: endpoint = deepcopy(endpoint) for code, response_data in data.items(): @@ -168,6 +175,7 @@ def _add_responses( status_code=status_code, data=response_data, schemas=schemas, + responses=responses, parent_name=endpoint.name, config=config, ) @@ -397,6 +405,7 @@ def from_data( schemas: Schemas, parameters: Parameters, request_bodies: dict[str, Union[oai.RequestBody, oai.Reference]], + responses: dict[str, Union[oai.Response, oai.Reference]], config: Config, ) -> tuple[Union["Endpoint", ParseError], Schemas, Parameters]: """Construct an endpoint from the OpenAPI data""" @@ -425,7 +434,13 @@ def from_data( ) if isinstance(result, ParseError): return result, schemas, parameters - result, schemas = Endpoint._add_responses(endpoint=result, data=data.responses, schemas=schemas, config=config) + result, schemas = Endpoint._add_responses( + endpoint=result, + data=data.responses, + schemas=schemas, + responses=responses, + config=config, + ) if isinstance(result, ParseError): return result, schemas, parameters bodies, schemas = body_from_data( @@ -515,8 +530,14 @@ def from_dict(data: dict[str, Any], *, config: Config) -> Union["GeneratorData", config=config, ) request_bodies = (openapi.components and openapi.components.requestBodies) or {} + responses = (openapi.components and openapi.components.responses) or {} endpoint_collections_by_tag, schemas, parameters = EndpointCollection.from_data( - data=openapi.paths, schemas=schemas, parameters=parameters, request_bodies=request_bodies, config=config + data=openapi.paths, + schemas=schemas, + parameters=parameters, + request_bodies=request_bodies, + responses=responses, + config=config, ) enums = ( diff --git a/openapi_python_client/parser/properties/schemas.py b/openapi_python_client/parser/properties/schemas.py index 177a86924..3114cae3b 100644 --- a/openapi_python_client/parser/properties/schemas.py +++ b/openapi_python_client/parser/properties/schemas.py @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]: return cast(ReferencePath, parsed.fragment) +def get_reference_simple_name(ref_path: str) -> str: + """ + Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`. + """ + return ref_path.split("/", 3)[-1] + + @define class Class: """Represents Python class which will be generated from an OpenAPI schema""" @@ -56,7 +63,7 @@ class Class: @staticmethod def from_string(*, string: str, config: Config) -> "Class": """Get a Class from an arbitrary string""" - class_name = string.split("/")[-1] # Get rid of ref path stuff + class_name = get_reference_simple_name(string) # Get rid of ref path stuff class_name = ClassName(class_name, config.field_prefix) override = config.class_overrides.get(class_name) diff --git a/openapi_python_client/parser/responses.py b/openapi_python_client/parser/responses.py index d313f81ad..ec0f6136b 100644 --- a/openapi_python_client/parser/responses.py +++ b/openapi_python_client/parser/responses.py @@ -6,6 +6,7 @@ from attrs import define from openapi_python_client import utils +from openapi_python_client.parser.properties.schemas import get_reference_simple_name, parse_reference_path from .. import Config from .. import schema as oai @@ -79,11 +80,12 @@ def empty_response( ) -def response_from_data( +def response_from_data( # noqa: PLR0911 *, status_code: HTTPStatus, data: Union[oai.Response, oai.Reference], schemas: Schemas, + responses: dict[str, Union[oai.Response, oai.Reference]], parent_name: str, config: Config, ) -> tuple[Union[Response, ParseError], Schemas]: @@ -91,15 +93,17 @@ def response_from_data( response_name = f"response_{status_code}" if isinstance(data, oai.Reference): - return ( - empty_response( - status_code=status_code, - response_name=response_name, - config=config, - data=data, - ), - schemas, - ) + ref_path = parse_reference_path(data.ref) + if isinstance(ref_path, ParseError): + return ref_path, schemas + if not ref_path.startswith("/components/responses/"): + return ParseError(data=data, detail=f"$ref to {data.ref} not allowed in responses"), schemas + resp_data = responses.get(get_reference_simple_name(ref_path), None) + if not resp_data: + return ParseError(data=data, detail=f"Could not find reference: {data.ref}"), schemas + if not isinstance(resp_data, oai.Response): + return ParseError(data=data, detail="Top-level $ref inside components/responses is not supported"), schemas + data = resp_data content = data.content if not content: diff --git a/tests/test_parser/test_openapi.py b/tests/test_parser/test_openapi.py index 6eeadcd78..7f0f7addf 100644 --- a/tests/test_parser/test_openapi.py +++ b/tests/test_parser/test_openapi.py @@ -85,7 +85,9 @@ def test__add_responses_status_code_error(self, response_status_code, mocker): response_from_data = mocker.patch(f"{MODULE_NAME}.response_from_data", return_value=(parse_error, schemas)) config = MagicMock() - response, schemas = Endpoint._add_responses(endpoint=endpoint, data=data, schemas=schemas, config=config) + response, schemas = Endpoint._add_responses( + endpoint=endpoint, data=data, schemas=schemas, responses={}, config=config + ) assert response.errors == [ ParseError( @@ -110,12 +112,28 @@ def test__add_responses_error(self, mocker): response_from_data = mocker.patch(f"{MODULE_NAME}.response_from_data", return_value=(parse_error, schemas)) config = MagicMock() - response, schemas = Endpoint._add_responses(endpoint=endpoint, data=data, schemas=schemas, config=config) + response, schemas = Endpoint._add_responses( + endpoint=endpoint, data=data, schemas=schemas, responses={}, config=config + ) response_from_data.assert_has_calls( [ - mocker.call(status_code=200, data=response_1_data, schemas=schemas, parent_name="name", config=config), - mocker.call(status_code=404, data=response_2_data, schemas=schemas, parent_name="name", config=config), + mocker.call( + status_code=200, + data=response_1_data, + schemas=schemas, + responses={}, + parent_name="name", + config=config, + ), + mocker.call( + status_code=404, + data=response_2_data, + schemas=schemas, + responses={}, + parent_name="name", + config=config, + ), ] ) assert response.errors == [ @@ -474,6 +492,7 @@ def test_from_data_bad_params(self, mocker, config): method=method, tag="default", schemas=initial_schemas, + responses={}, parameters=parameters, config=config, request_bodies={}, @@ -509,6 +528,7 @@ def test_from_data_bad_responses(self, mocker, config): method=method, tag="default", schemas=initial_schemas, + responses={}, parameters=initial_parameters, config=config, request_bodies={}, @@ -549,6 +569,7 @@ def test_from_data_standard(self, mocker, config): method=method, tag="default", schemas=initial_schemas, + responses={}, parameters=initial_parameters, config=config, request_bodies={}, @@ -570,7 +591,7 @@ def test_from_data_standard(self, mocker, config): config=config, ) _add_responses.assert_called_once_with( - endpoint=param_endpoint, data=data.responses, schemas=param_schemas, config=config + endpoint=param_endpoint, data=data.responses, schemas=param_schemas, responses={}, config=config ) def test_from_data_no_operation_id(self, mocker, config): @@ -600,6 +621,7 @@ def test_from_data_no_operation_id(self, mocker, config): method=method, tag="default", schemas=schemas, + responses={}, parameters=parameters, config=config, request_bodies={}, @@ -624,6 +646,7 @@ def test_from_data_no_operation_id(self, mocker, config): endpoint=add_parameters.return_value[0], data=data.responses, schemas=add_parameters.return_value[1], + responses={}, config=config, ) @@ -654,6 +677,7 @@ def test_from_data_no_security(self, mocker, config): method=method, tag="a", schemas=schemas, + responses={}, parameters=parameters, config=config, request_bodies={}, @@ -678,6 +702,7 @@ def test_from_data_no_security(self, mocker, config): endpoint=add_parameters.return_value[0], data=data.responses, schemas=add_parameters.return_value[1], + responses={}, config=config, ) @@ -693,6 +718,7 @@ def test_from_data_some_bad_bodies(self, config): ), ), schemas=Schemas(), + responses={}, config=config, parameters=Parameters(), tag="tag", @@ -716,6 +742,7 @@ def test_from_data_all_bodies_bad(self, config): ), ), schemas=Schemas(), + responses={}, config=config, parameters=Parameters(), tag="tag", @@ -787,6 +814,7 @@ def test_from_data_overrides_path_item_params_with_operation_params(self, config parameters=Parameters(), config=config, request_bodies={}, + responses={}, ) collection: EndpointCollection = collections["default"] assert isinstance(collection.endpoints[0].query_parameters[0], IntProperty) @@ -825,6 +853,7 @@ def test_from_data_errors(self, mocker, config): config=config, parameters=parameters, request_bodies={}, + responses={}, ) assert result["default"].parse_errors[0].data == "1" @@ -866,7 +895,7 @@ def test_from_data_tags_snake_case_sanitizer(self, mocker, config): parameters = mocker.MagicMock() result = EndpointCollection.from_data( - data=data, schemas=schemas, parameters=parameters, config=config, request_bodies={} + data=data, schemas=schemas, parameters=parameters, config=config, request_bodies={}, responses={} ) assert result == ( diff --git a/tests/test_parser/test_responses.py b/tests/test_parser/test_responses.py index 0ac885764..24fb94c61 100644 --- a/tests/test_parser/test_responses.py +++ b/tests/test_parser/test_responses.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock +import pytest + import openapi_python_client.schema as oai from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas @@ -17,6 +19,7 @@ def test_response_from_data_no_content(any_property_factory): status_code=200, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=MagicMock(), ) @@ -34,31 +37,6 @@ def test_response_from_data_no_content(any_property_factory): ) -def test_response_from_data_reference(any_property_factory): - from openapi_python_client.parser.responses import Response, response_from_data - - data = oai.Reference.model_construct() - - response, schemas = response_from_data( - status_code=200, - data=data, - schemas=Schemas(), - parent_name="parent", - config=MagicMock(), - ) - - assert response == Response( - status_code=200, - prop=any_property_factory( - name="response_200", - default=None, - required=True, - ), - source=NONE_SOURCE, - data=data, - ) - - def test_response_from_data_unsupported_content_type(): from openapi_python_client.parser.responses import response_from_data @@ -69,6 +47,7 @@ def test_response_from_data_unsupported_content_type(): status_code=200, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=config, ) @@ -89,6 +68,7 @@ def test_response_from_data_no_content_schema(any_property_factory): status_code=200, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=config, ) @@ -121,6 +101,7 @@ def test_response_from_data_property_error(mocker): status_code=400, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=config, ) @@ -152,6 +133,7 @@ def test_response_from_data_property(mocker, any_property_factory): status_code=400, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=config, ) @@ -172,6 +154,99 @@ def test_response_from_data_property(mocker, any_property_factory): ) +def test_response_from_data_reference(mocker, any_property_factory): + from openapi_python_client.parser import responses + + prop = any_property_factory() + mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas())) + predefined_response_data = oai.Response.model_construct( + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, + ) + config = MagicMock() + config.content_type_overrides = {} + + response, schemas = responses.response_from_data( + status_code=400, + data=oai.Reference.model_construct(ref="#/components/responses/ErrorResponse"), + schemas=Schemas(), + responses={"ErrorResponse": predefined_response_data}, + parent_name="parent", + config=config, + ) + + assert response == responses.Response( + status_code=400, + prop=prop, + source=JSON_SOURCE, + data=predefined_response_data, + ) + + +@pytest.mark.parametrize( + "ref_string,expected_error_string", + [ + ("#/components/responses/Nonexistent", "Could not find"), + ("https://remote-reference", "Remote references"), + ("#/components/something-that-isnt-responses/ErrorResponse", "not allowed in responses"), + ], +) +def test_response_from_data_invalid_reference(ref_string, expected_error_string, mocker, any_property_factory): + from openapi_python_client.parser import responses + + prop = any_property_factory() + mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas())) + predefined_response_data = oai.Response.model_construct( + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, + ) + config = MagicMock() + config.content_type_overrides = {} + + response, schemas = responses.response_from_data( + status_code=400, + data=oai.Reference.model_construct(ref=ref_string), + schemas=Schemas(), + responses={"ErrorResponse": predefined_response_data}, + parent_name="parent", + config=config, + ) + + assert isinstance(response, ParseError) + assert expected_error_string in response.detail + + +def test_response_from_data_ref_to_response_that_is_a_ref(mocker, any_property_factory): + from openapi_python_client.parser import responses + + prop = any_property_factory() + mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas())) + predefined_response_base_data = oai.Response.model_construct( + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, + ) + predefined_response_data = oai.Reference.model_construct( + ref="#/components/references/BaseResponse", + ) + config = MagicMock() + config.content_type_overrides = {} + + response, schemas = responses.response_from_data( + status_code=400, + data=oai.Reference.model_construct(ref="#/components/responses/ErrorResponse"), + schemas=Schemas(), + responses={ + "BaseResponse": predefined_response_base_data, + "ErrorResponse": predefined_response_data, + }, + parent_name="parent", + config=config, + ) + + assert isinstance(response, ParseError) + assert "Top-level $ref" in response.detail + + def test_response_from_data_content_type_overrides(any_property_factory): from openapi_python_client.parser.responses import Response, response_from_data @@ -185,6 +260,7 @@ def test_response_from_data_content_type_overrides(any_property_factory): status_code=200, data=data, schemas=Schemas(), + responses={}, parent_name="parent", config=config, )