Skip to content

fix!: support $ref from endpoint response to components/responses #1148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,20 @@
}
}
},
"/responses/reference": {
"get": {
"tags": [
"responses"
],
"summary": "Endpoint using predefined response",
"operationId": "reference_response",
"responses": {
"200": {
"$ref": "#/components/responses/AResponse"
}
}
}
},
"/auth/token_with_cookie": {
"get": {
"tags": [
Expand Down Expand Up @@ -2971,6 +2985,18 @@
}
}
}
},
"responses": {
"AResponse": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AModel"
}
}
}
}
}
}
}
21 changes: 21 additions & 0 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,20 @@ info:
}
}
},
"/responses/reference": {
"get": {
"tags": [
"responses"
],
"summary": "Endpoint using predefined response",
"operationId": "reference_response",
"responses": {
"200": {
"$ref": "#/components/responses/AResponse"
}
}
}
},
"/auth/token_with_cookie": {
"get": {
"tags": [
Expand Down Expand Up @@ -2962,3 +2976,10 @@ info:
"application/json":
"schema":
"$ref": "#/components/schemas/AModel"
responses:
AResponse:
description: OK
content:
"application/json":
"schema":
"$ref": "#/components/schemas/AModel"
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import types

from . import post_responses_unions_simple_before_complex, text_response
from . import post_responses_unions_simple_before_complex, reference_response, text_response


class ResponsesEndpoints:
Expand All @@ -19,3 +19,10 @@ def text_response(cls) -> types.ModuleType:
Text Response
"""
return text_response

@classmethod
def reference_response(cls) -> types.ModuleType:
"""
Endpoint using predefined response
"""
return reference_response
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from http import HTTPStatus
from typing import Any, Optional, Union

import httpx

from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.a_model import AModel
from ...types import Response


def _get_kwargs() -> dict[str, Any]:
_kwargs: dict[str, Any] = {
"method": "get",
"url": "/responses/reference",
}

return _kwargs


def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[AModel]:
if response.status_code == 200:
response_200 = AModel.from_dict(response.json())

return response_200
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None


def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[AModel]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)


def sync_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[AModel]:
"""Endpoint using predefined response

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[AModel]
"""

kwargs = _get_kwargs()

response = client.get_httpx_client().request(
**kwargs,
)

return _build_response(client=client, response=response)


def sync(
*,
client: Union[AuthenticatedClient, Client],
) -> Optional[AModel]:
"""Endpoint using predefined response

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
AModel
"""

return sync_detailed(
client=client,
).parsed


async def asyncio_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[AModel]:
"""Endpoint using predefined response

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[AModel]
"""

kwargs = _get_kwargs()

response = await client.get_async_httpx_client().request(**kwargs)

return _build_response(client=client, response=response)


async def asyncio(
*,
client: Union[AuthenticatedClient, Client],
) -> Optional[AModel]:
"""Endpoint using predefined response

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
AModel
"""

return (
await asyncio_detailed(
client=client,
)
).parsed
3 changes: 2 additions & 1 deletion openapi_python_client/parser/bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Schemas,
property_from_data,
)
from openapi_python_client.parser.properties.schemas import get_reference_simple_name

from .. import schema as oai
from ..config import Config
Expand Down Expand Up @@ -138,7 +139,7 @@ def _resolve_reference(
references_seen = []
while isinstance(body, oai.Reference) and body.ref not in references_seen:
references_seen.append(body.ref)
body = request_bodies.get(body.ref.split("/")[-1])
body = request_bodies.get(get_reference_simple_name(body.ref))
if isinstance(body, oai.Reference):
return ParseError(detail="Circular $ref in request body", data=body)
if body is None and references_seen:
Expand Down
27 changes: 24 additions & 3 deletions openapi_python_client/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def from_data(
schemas: Schemas,
parameters: Parameters,
request_bodies: dict[str, Union[oai.RequestBody, oai.Reference]],
responses: dict[str, Union[oai.Response, oai.Reference]],
config: Config,
) -> tuple[dict[utils.PythonIdentifier, "EndpointCollection"], Schemas, Parameters]:
"""Parse the openapi paths data to get EndpointCollections by tag"""
Expand All @@ -73,6 +74,7 @@ def from_data(
schemas=schemas,
parameters=parameters,
request_bodies=request_bodies,
responses=responses,
config=config,
)
# Add `PathItem` parameters
Expand Down Expand Up @@ -145,7 +147,12 @@ class Endpoint:

@staticmethod
def _add_responses(
*, endpoint: "Endpoint", data: oai.Responses, schemas: Schemas, config: Config
*,
endpoint: "Endpoint",
data: oai.Responses,
schemas: Schemas,
responses: dict[str, Union[oai.Response, oai.Reference]],
config: Config,
) -> tuple["Endpoint", Schemas]:
endpoint = deepcopy(endpoint)
for code, response_data in data.items():
Expand All @@ -168,6 +175,7 @@ def _add_responses(
status_code=status_code,
data=response_data,
schemas=schemas,
responses=responses,
parent_name=endpoint.name,
config=config,
)
Expand Down Expand Up @@ -397,6 +405,7 @@ def from_data(
schemas: Schemas,
parameters: Parameters,
request_bodies: dict[str, Union[oai.RequestBody, oai.Reference]],
responses: dict[str, Union[oai.Response, oai.Reference]],
config: Config,
) -> tuple[Union["Endpoint", ParseError], Schemas, Parameters]:
"""Construct an endpoint from the OpenAPI data"""
Expand Down Expand Up @@ -425,7 +434,13 @@ def from_data(
)
if isinstance(result, ParseError):
return result, schemas, parameters
result, schemas = Endpoint._add_responses(endpoint=result, data=data.responses, schemas=schemas, config=config)
result, schemas = Endpoint._add_responses(
endpoint=result,
data=data.responses,
schemas=schemas,
responses=responses,
config=config,
)
if isinstance(result, ParseError):
return result, schemas, parameters
bodies, schemas = body_from_data(
Expand Down Expand Up @@ -515,8 +530,14 @@ def from_dict(data: dict[str, Any], *, config: Config) -> Union["GeneratorData",
config=config,
)
request_bodies = (openapi.components and openapi.components.requestBodies) or {}
responses = (openapi.components and openapi.components.responses) or {}
endpoint_collections_by_tag, schemas, parameters = EndpointCollection.from_data(
data=openapi.paths, schemas=schemas, parameters=parameters, request_bodies=request_bodies, config=config
data=openapi.paths,
schemas=schemas,
parameters=parameters,
request_bodies=request_bodies,
responses=responses,
config=config,
)

enums = (
Expand Down
9 changes: 8 additions & 1 deletion openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]:
return cast(ReferencePath, parsed.fragment)


def get_reference_simple_name(ref_path: str) -> str:
"""
Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`.
"""
return ref_path.split("/", 3)[-1]


@define
class Class:
"""Represents Python class which will be generated from an OpenAPI schema"""
Expand All @@ -56,7 +63,7 @@ class Class:
@staticmethod
def from_string(*, string: str, config: Config) -> "Class":
"""Get a Class from an arbitrary string"""
class_name = string.split("/")[-1] # Get rid of ref path stuff
class_name = get_reference_simple_name(string) # Get rid of ref path stuff
class_name = ClassName(class_name, config.field_prefix)
override = config.class_overrides.get(class_name)

Expand Down
24 changes: 14 additions & 10 deletions openapi_python_client/parser/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from attrs import define

from openapi_python_client import utils
from openapi_python_client.parser.properties.schemas import get_reference_simple_name, parse_reference_path

from .. import Config
from .. import schema as oai
Expand Down Expand Up @@ -79,27 +80,30 @@ def empty_response(
)


def response_from_data(
def response_from_data( # noqa: PLR0911
*,
status_code: HTTPStatus,
data: Union[oai.Response, oai.Reference],
schemas: Schemas,
responses: dict[str, Union[oai.Response, oai.Reference]],
parent_name: str,
config: Config,
) -> tuple[Union[Response, ParseError], Schemas]:
"""Generate a Response from the OpenAPI dictionary representation of it"""

response_name = f"response_{status_code}"
if isinstance(data, oai.Reference):
return (
empty_response(
status_code=status_code,
response_name=response_name,
config=config,
data=data,
),
schemas,
)
ref_path = parse_reference_path(data.ref)
if isinstance(ref_path, ParseError):
return ref_path, schemas
if not ref_path.startswith("/components/responses/"):
return ParseError(data=data, detail=f"$ref to {data.ref} not allowed in responses"), schemas
resp_data = responses.get(get_reference_simple_name(ref_path), None)
if not resp_data:
return ParseError(data=data, detail=f"Could not find reference: {data.ref}"), schemas
if not isinstance(resp_data, oai.Response):
return ParseError(data=data, detail="Top-level $ref inside components/responses is not supported"), schemas
data = resp_data

content = data.content
if not content:
Expand Down
Loading
Loading