diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index dbd0beeb..4fa8cd80 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -71,10 +71,14 @@ def __init__( h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded - | h2.events.StreamReset, + | h2.events.StreamReset + | h2.events.TrailersReceived, ], ] = {} + # Mapping from stream ID to trailing headers + self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {} + # Connection terminated events are stored as state since # we need to handle them for all streams. self._connection_terminated: h2.events.ConnectionTerminated | None = None @@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response: ) trace.return_value = (status, headers) + extensions = { + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + } + return Response( status=status, headers=headers, - content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), - extensions={ - "http_version": b"HTTP/2", - "network_stream": self._network_stream, - "stream_id": stream_id, - }, + content=HTTP2ConnectionByteStream( + connection=self, + request=request, + stream_id=stream_id, + extensions=extensions, + ), + extensions=extensions, ) except BaseException as exc: # noqa: PIE786 with AsyncShieldCancellation(): @@ -326,7 +337,12 @@ async def _receive_response_body( async def _receive_stream_event( self, request: Request, stream_id: int - ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded: + ) -> ( + h2.events.ResponseReceived + | h2.events.DataReceived + | h2.events.StreamEnded + | h2.events.TrailersReceived + ): """ Return the next available event for a given stream ID. @@ -337,6 +353,13 @@ async def _receive_stream_event( event = self._events[stream_id].pop(0) if isinstance(event, h2.events.StreamReset): raise RemoteProtocolError(event) + elif isinstance(event, h2.events.TrailersReceived): + if event.stream_id in self._events and event.headers is not None: + self._trailing_headers[event.stream_id] = [] + for k, v in event.headers: + if not k.startswith(b":"): + self._trailing_headers[event.stream_id].append((k, v)) + return event async def _receive_events( @@ -377,6 +400,7 @@ async def _receive_events( h2.events.DataReceived, h2.events.StreamEnded, h2.events.StreamReset, + h2.events.TrailersReceived, ), ): if event.stream_id in self._events: @@ -409,6 +433,8 @@ async def _receive_remote_settings_change( async def _response_closed(self, stream_id: int) -> None: await self._max_streams_semaphore.release() del self._events[stream_id] + if stream_id in self._trailing_headers: + del self._trailing_headers[stream_id] async with self._state_lock: if self._connection_terminated and not self._events: await self.aclose() @@ -561,12 +587,17 @@ async def __aexit__( class HTTP2ConnectionByteStream: def __init__( - self, connection: AsyncHTTP2Connection, request: Request, stream_id: int + self, + connection: AsyncHTTP2Connection, + request: Request, + stream_id: int, + extensions: typing.MutableMapping[str, typing.Any], ) -> None: self._connection = connection self._request = request self._stream_id = stream_id self._closed = False + self._extensions = extensions async def __aiter__(self) -> typing.AsyncIterator[bytes]: kwargs = {"request": self._request, "stream_id": self._stream_id} @@ -576,6 +607,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: request=self._request, stream_id=self._stream_id ): yield chunk + + if self._stream_id in self._connection._trailing_headers: + self._extensions["trailing_headers"] = ( + self._connection._trailing_headers[self._stream_id] + ) except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index ddcc1890..7ddd409d 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -71,10 +71,14 @@ def __init__( h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded - | h2.events.StreamReset, + | h2.events.StreamReset + | h2.events.TrailersReceived, ], ] = {} + # Mapping from stream ID to trailing headers + self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {} + # Connection terminated events are stored as state since # we need to handle them for all streams. self._connection_terminated: h2.events.ConnectionTerminated | None = None @@ -152,15 +156,22 @@ def handle_request(self, request: Request) -> Response: ) trace.return_value = (status, headers) + extensions = { + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + } + return Response( status=status, headers=headers, - content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), - extensions={ - "http_version": b"HTTP/2", - "network_stream": self._network_stream, - "stream_id": stream_id, - }, + content=HTTP2ConnectionByteStream( + connection=self, + request=request, + stream_id=stream_id, + extensions=extensions, + ), + extensions=extensions, ) except BaseException as exc: # noqa: PIE786 with ShieldCancellation(): @@ -326,7 +337,12 @@ def _receive_response_body( def _receive_stream_event( self, request: Request, stream_id: int - ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded: + ) -> ( + h2.events.ResponseReceived + | h2.events.DataReceived + | h2.events.StreamEnded + | h2.events.TrailersReceived + ): """ Return the next available event for a given stream ID. @@ -337,6 +353,13 @@ def _receive_stream_event( event = self._events[stream_id].pop(0) if isinstance(event, h2.events.StreamReset): raise RemoteProtocolError(event) + elif isinstance(event, h2.events.TrailersReceived): + if event.stream_id in self._events and event.headers is not None: + self._trailing_headers[event.stream_id] = [] + for k, v in event.headers: + if not k.startswith(b":"): + self._trailing_headers[event.stream_id].append((k, v)) + return event def _receive_events( @@ -377,6 +400,7 @@ def _receive_events( h2.events.DataReceived, h2.events.StreamEnded, h2.events.StreamReset, + h2.events.TrailersReceived, ), ): if event.stream_id in self._events: @@ -409,6 +433,8 @@ def _receive_remote_settings_change( def _response_closed(self, stream_id: int) -> None: self._max_streams_semaphore.release() del self._events[stream_id] + if stream_id in self._trailing_headers: + del self._trailing_headers[stream_id] with self._state_lock: if self._connection_terminated and not self._events: self.close() @@ -561,12 +587,17 @@ def __exit__( class HTTP2ConnectionByteStream: def __init__( - self, connection: HTTP2Connection, request: Request, stream_id: int + self, + connection: HTTP2Connection, + request: Request, + stream_id: int, + extensions: typing.MutableMapping[str, typing.Any], ) -> None: self._connection = connection self._request = request self._stream_id = stream_id self._closed = False + self._extensions = extensions def __iter__(self) -> typing.Iterator[bytes]: kwargs = {"request": self._request, "stream_id": self._stream_id} @@ -576,6 +607,11 @@ def __iter__(self) -> typing.Iterator[bytes]: request=self._request, stream_id=self._stream_id ): yield chunk + + if self._stream_id in self._connection._trailing_headers: + self._extensions["trailing_headers"] = ( + self._connection._trailing_headers[self._stream_id] + ) except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/tests/_async/test_http2_trailing_headers.py b/tests/_async/test_http2_trailing_headers.py new file mode 100644 index 00000000..6dc900db --- /dev/null +++ b/tests/_async/test_http2_trailing_headers.py @@ -0,0 +1,156 @@ +import hpack +import hyperframe.frame +import pytest + +import httpcore + + +@pytest.mark.anyio +async def test_http2_connection_with_trailing_headers(): + """ + Test that trailing headers are correctly received and processed. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + # Send trailing headers + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + async with httpcore.AsyncHTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + # Check that trailing headers are included in extensions + assert "trailing_headers" in response.extensions + assert response.extensions["trailing_headers"] == [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + + +@pytest.mark.anyio +async def test_http2_connection_with_body_and_trailing_headers(): + """ + Test that trailing headers are correctly received and processed + when reading the response body in chunks. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, ").serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"world!").serialize(), + # Send trailing headers + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + + async with httpcore.AsyncHTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + async with conn.stream("GET", "https://example.com/") as response: + content = b"" + async for chunk in response.aiter_stream(): + content += chunk + + assert response.status == 200 + assert content == b"Hello, world!" + + # Check that trailing headers are included in extensions + assert "trailing_headers" in response.extensions + assert response.extensions["trailing_headers"] == [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + + +@pytest.mark.anyio +async def test_http2_connection_with_trailing_headers_pseudo_removed(): + """ + Test that pseudo-headers in trailing headers are correctly filtered out. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + # Send trailing headers with a pseudo-header which should be filtered out + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":pseudo", b"should-be-filtered"), + (b"x-trailer", b"trailer-value"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + async with httpcore.AsyncHTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + # Check that trailing headers are included in extensions but pseudo-headers are filtered + assert "trailing_headers" in response.extensions + assert len(response.extensions["trailing_headers"]) == 1 + assert response.extensions["trailing_headers"] == [ + (b"x-trailer", b"trailer-value"), + ] diff --git a/tests/_sync/test_http2_trailing_headers.py b/tests/_sync/test_http2_trailing_headers.py new file mode 100644 index 00000000..7760ca1a --- /dev/null +++ b/tests/_sync/test_http2_trailing_headers.py @@ -0,0 +1,156 @@ +import hpack +import hyperframe.frame +import pytest + +import httpcore + + + +def test_http2_connection_with_trailing_headers(): + """ + Test that trailing headers are correctly received and processed. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + # Send trailing headers + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + with httpcore.HTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + # Check that trailing headers are included in extensions + assert "trailing_headers" in response.extensions + assert response.extensions["trailing_headers"] == [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + + + +def test_http2_connection_with_body_and_trailing_headers(): + """ + Test that trailing headers are correctly received and processed + when reading the response body in chunks. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, ").serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"world!").serialize(), + # Send trailing headers + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + + with httpcore.HTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + with conn.stream("GET", "https://example.com/") as response: + content = b"" + for chunk in response.iter_stream(): + content += chunk + + assert response.status == 200 + assert content == b"Hello, world!" + + # Check that trailing headers are included in extensions + assert "trailing_headers" in response.extensions + assert response.extensions["trailing_headers"] == [ + (b"x-trailer-1", b"trailer-value-1"), + (b"x-trailer-2", b"trailer-value-2"), + ] + + + +def test_http2_connection_with_trailing_headers_pseudo_removed(): + """ + Test that pseudo-headers in trailing headers are correctly filtered out. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(), + # Send trailing headers with a pseudo-header which should be filtered out + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":pseudo", b"should-be-filtered"), + (b"x-trailer", b"trailer-value"), + ] + ), + flags=["END_HEADERS", "END_STREAM"], + ).serialize(), + ] + ) + with httpcore.HTTP2Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + # Check that trailing headers are included in extensions but pseudo-headers are filtered + assert "trailing_headers" in response.extensions + assert len(response.extensions["trailing_headers"]) == 1 + assert response.extensions["trailing_headers"] == [ + (b"x-trailer", b"trailer-value"), + ]